Fix exllama tokenizers (#3954)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
Author: saltacc, 2023-09-16 12:42:38 +00:00 (committed by GitHub)
parent 8d85425e09
commit ed6b6411fb
2 changed files with 31 additions and 10 deletions

modules/exllama.py:

@@ -1,5 +1,6 @@
 from pathlib import Path
+import torch
 import torch.nn.functional as F
 from torch import version as torch_version
@@ -111,7 +112,7 @@ class ExllamaModel:
         if state['custom_token_bans']:
             to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
             if len(to_ban) > 0:
-                self.generator.disallow_tokens(self.tokenizer, to_ban)
+                self.generator.disallow_tokens(to_ban)

         # Case 1: no CFG
         if state['guidance_scale'] == 1:
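For context, the custom_token_bans path turns a comma-separated string from the UI into a list of token ids before handing them to the sampler. A minimal standalone sketch with made-up values:

    # Made-up UI state; the real dict comes from the webui's generation settings
    state = {'custom_token_bans': '0, 2,13'}
    to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
    print(to_ban)  # [0, 2, 13] -- int() tolerates the stray whitespace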
@@ -119,6 +120,11 @@ class ExllamaModel:
             # Tokenizing the input
             ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len)
+            if state['add_bos_token']:
+                ids = torch.cat(
+                    [torch.tensor([[self.tokenizer.bos_token_id]]).to(ids.device),
+                     ids], dim=1
+                ).to(torch.int64)

             ids = ids[:, -get_max_prompt_length(state):]
             if state['auto_max_new_tokens']:
                 max_new_tokens = state['truncation_length'] - ids.shape[-1]
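The added block restores the add_bos_token option for the no-CFG path: the encode call here does not prepend a BOS token, so the wrapper concatenates one manually. A minimal standalone sketch of the same operation, using a stand-in bos_token_id of 1 (Llama-style; the real value comes from the loaded tokenizer):

    import torch

    bos_token_id = 1                        # stand-in; model-specific in practice
    add_bos_token = True                    # mirrors state['add_bos_token']
    ids = torch.tensor([[3923, 374, 279]])  # pretend output of tokenizer.encode(prompt)

    if add_bos_token:
        # Prepend one BOS column to the (1, seq_len) batch, matching device/dtype
        ids = torch.cat(
            [torch.tensor([[bos_token_id]]).to(ids.device), ids], dim=1
        ).to(torch.int64)

    print(ids)  # tensor([[   1, 3923,  374,  279]])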
@@ -148,7 +154,12 @@ class ExllamaModel:
             alpha = state['guidance_scale']
             prompts = [prompt, state['negative_prompt'] or '']

-            ids, mask = self.tokenizer.encode(prompts, return_mask=True, max_seq_len=self.model.config.max_seq_len)
+            ids, mask = self.tokenizer.encode(
+                prompts,
+                return_mask=True,
+                max_seq_len=self.model.config.max_seq_len,
+                add_bos=state['add_bos_token']
+            )
             if state['auto_max_new_tokens']:
                 max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
             else:
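In the CFG branch the positive and negative prompts are encoded as one batch, so the shorter row has to be padded and the mask has to flag the padding; that is what return_mask=True is for, and the new add_bos argument keeps BOS handling consistent with the no-CFG path. A rough, self-contained sketch of the padded-batch idea with made-up token ids (the actual padding side, pad value, and mask dtype are ExLlama implementation details):

    import torch
    import torch.nn.functional as F

    # Hypothetical token ids for the positive and negative prompts (CFG pair)
    pos = torch.tensor([5, 6, 7, 8])
    neg = torch.tensor([9, 10])

    # Left-pad the shorter row so both share one (2, max_len) batch, and build
    # a mask that is 0 on the padding positions
    max_len = max(pos.shape[-1], neg.shape[-1])
    ids = torch.stack([F.pad(t, (max_len - t.shape[-1], 0)) for t in (pos, neg)])
    mask = torch.stack([F.pad(torch.ones_like(t), (max_len - t.shape[-1], 0)) for t in (pos, neg)])

    print(ids)   # tensor([[ 5,  6,  7,  8], [ 0,  0,  9, 10]])
    print(mask)  # tensor([[1, 1, 1, 1], [0, 0, 1, 1]])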
@@ -188,7 +199,12 @@ class ExllamaModel:
         return output

     def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len)
+        return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)

-    def decode(self, string, **kwargs):
-        return self.tokenizer.decode(string)[0]
+    def decode(self, ids, **kwargs):
+        if isinstance(ids, int):
+            ids = torch.tensor([[ids]])
+        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
+            ids = ids.view(1, -1)
+
+        return self.tokenizer.decode(ids)[0]
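The decode fix matters because callers pass single token ids as plain ints or one-element tensors, while the tokenizer's decode expects a (batch, seq_len) tensor; the old signature passed whatever it received straight through. The new preamble, extracted as a standalone sketch:

    import torch

    def normalize_ids(ids):
        # Same checks as the new decode(): coerce a bare int or a one-element
        # tensor into the (batch, seq_len) shape the tokenizer expects
        if isinstance(ids, int):
            ids = torch.tensor([[ids]])
        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
            ids = ids.view(1, -1)
        return ids

    print(normalize_ids(42).shape)                  # torch.Size([1, 1])
    print(normalize_ids(torch.tensor(42)).shape)    # torch.Size([1, 1])
    print(normalize_ids(torch.tensor([42])).shape)  # torch.Size([1, 1])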

modules/exllamav2.py:

@@ -48,7 +48,7 @@ class Exllamav2Model:
         result.cache = cache
         result.tokenizer = tokenizer
         result.generator = generator
-        return result, tokenizer
+        return result, result

     def generate_with_streaming(self, prompt, state):
         settings = ExLlamaV2Sampler.Settings()
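The `return result, result` change makes the loader hand back the wrapper itself as the "tokenizer", so code that calls .encode()/.decode() on the second return value goes through the fixed methods below rather than the raw ExLlamaV2 tokenizer. A toy version of the pattern (all names hypothetical):

    class Wrapper:
        # Stand-in for Exllamav2Model: encode adds a fake BOS id, decode strips it
        def encode(self, s):
            return [1] + [ord(c) for c in s]

        def decode(self, ids):
            return ''.join(chr(i) for i in ids[1:])

    def load():
        result = Wrapper()
        return result, result  # model and "tokenizer" are the same object

    model, tokenizer = load()
    assert tokenizer.decode(tokenizer.encode("hi")) == "hi"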
@@ -65,7 +65,7 @@ class Exllamav2Model:
             if len(to_ban) > 0:
                 settings.disallow_tokens(self.tokenizer, to_ban)

-        ids = self.tokenizer.encode(prompt)
+        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'])
         ids = ids[:, -get_max_prompt_length(state):]
         initial_len = ids.shape[-1]
@@ -104,7 +104,12 @@ class Exllamav2Model:
         return output

     def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string)
+        return self.tokenizer.encode(string, add_bos=True)

-    def decode(self, string, **kwargs):
-        return self.tokenizer.decode(string)[0]
+    def decode(self, ids, **kwargs):
+        if isinstance(ids, int):
+            ids = torch.tensor([[ids]])
+        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
+            ids = ids.view(1, -1)
+
+        return self.tokenizer.decode(ids)[0]