From ed6b6411fba9b394836ff593465607e362a56276 Mon Sep 17 00:00:00 2001
From: saltacc
Date: Sat, 16 Sep 2023 12:42:38 +0000
Subject: [PATCH] Fix exllama tokenizers (#3954)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
---
 modules/exllama.py   | 26 +++++++++++++++++++++-----
 modules/exllamav2.py | 15 ++++++++++-----
 2 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/modules/exllama.py b/modules/exllama.py
index 177f028f..f3894b7a 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 
+import torch
 import torch.nn.functional as F
 from torch import version as torch_version
 
@@ -111,7 +112,7 @@ class ExllamaModel:
         if state['custom_token_bans']:
             to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
             if len(to_ban) > 0:
-                self.generator.disallow_tokens(self.tokenizer, to_ban)
+                self.generator.disallow_tokens(to_ban)
 
         # Case 1: no CFG
         if state['guidance_scale'] == 1:
@@ -119,6 +120,11 @@
 
             # Tokenizing the input
             ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len)
+            if state['add_bos_token']:
+                ids = torch.cat(
+                    [torch.tensor([[self.tokenizer.bos_token_id]]).to(ids.device),
+                     ids], dim=1
+                ).to(torch.int64)
             ids = ids[:, -get_max_prompt_length(state):]
             if state['auto_max_new_tokens']:
                 max_new_tokens = state['truncation_length'] - ids.shape[-1]
@@ -148,7 +154,12 @@
             alpha = state['guidance_scale']
             prompts = [prompt, state['negative_prompt'] or '']
 
-            ids, mask = self.tokenizer.encode(prompts, return_mask=True, max_seq_len=self.model.config.max_seq_len)
+            ids, mask = self.tokenizer.encode(
+                prompts,
+                return_mask=True,
+                max_seq_len=self.model.config.max_seq_len,
+                add_bos=state['add_bos_token']
+            )
             if state['auto_max_new_tokens']:
                 max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
             else:
@@ -188,7 +199,12 @@
         return output
 
     def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len)
+        return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
 
-    def decode(self, string, **kwargs):
-        return self.tokenizer.decode(string)[0]
+    def decode(self, ids, **kwargs):
+        if isinstance(ids, int):
+            ids = torch.tensor([[ids]])
+        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
+            ids = ids.view(1, -1)
+
+        return self.tokenizer.decode(ids)[0]
diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index a325a4d3..0bfe1f73 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -48,7 +48,7 @@ class Exllamav2Model:
         result.cache = cache
         result.tokenizer = tokenizer
         result.generator = generator
-        return result, tokenizer
+        return result, result
 
     def generate_with_streaming(self, prompt, state):
         settings = ExLlamaV2Sampler.Settings()
@@ -65,7 +65,7 @@
             if len(to_ban) > 0:
                 settings.disallow_tokens(self.tokenizer, to_ban)
 
-        ids = self.tokenizer.encode(prompt)
+        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'])
         ids = ids[:, -get_max_prompt_length(state):]
         initial_len = ids.shape[-1]
 
@@ -104,7 +104,12 @@
         return output
 
     def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string)
+        return self.tokenizer.encode(string, add_bos=True)
 
-    def decode(self, string, **kwargs):
-        return self.tokenizer.decode(string)[0]
+    def decode(self, ids, **kwargs):
+        if isinstance(ids, int):
+            ids = torch.tensor([[ids]])
+        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
+            ids = ids.view(1, -1)
+
+        return self.tokenizer.decode(ids)[0]