diff --git a/modules/text_generation.py b/modules/text_generation.py index 75294013..4f324239 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -138,9 +138,21 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt input_ids = np.array(input_ids).reshape(1, len(input_ids)) else: input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens) - if not add_bos_token: - while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id: - input_ids = input_ids[:, 1:] + + if hasattr(shared.tokenizer, 'bos_token_id'): + if add_bos_token: + if (len(input_ids[0]) > 0 and input_ids[0][0] != shared.tokenizer.bos_token_id) or len(input_ids[0]) == 0: + # Add a missing bos token (it may not have been added due to faulty model metadata) + bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]]) + input_ids = torch.cat((bos_tensor, input_ids), 1) + + # Prevent double bos token due to jinja templates with somewhere + while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id: + input_ids = input_ids[:, 1:] + else: + # Remove any bos token that may have been added + while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id: + input_ids = input_ids[:, 1:] # Handling truncation if truncation_length is not None: