Fix ExLlama truncation

This commit is contained in:
oobabooga 2023-08-20 08:50:32 -07:00
parent ee964bcce9
commit ef17da70af

View File

@ -111,7 +111,7 @@ class ExllamaModel:
self.generator.end_beam_search() self.generator.end_beam_search()
# Tokenizing the input # Tokenizing the input
ids = self.generator.tokenizer.encode(prompt) ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len)
ids = ids[:, -get_max_prompt_length(state):] ids = ids[:, -get_max_prompt_length(state):]
if state['auto_max_new_tokens']: if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - ids.shape[-1] max_new_tokens = state['truncation_length'] - ids.shape[-1]
@ -141,7 +141,7 @@ class ExllamaModel:
alpha = state['guidance_scale'] alpha = state['guidance_scale']
prompts = [prompt, state['negative_prompt'] or ''] prompts = [prompt, state['negative_prompt'] or '']
ids, mask = self.tokenizer.encode(prompts, return_mask=True) ids, mask = self.tokenizer.encode(prompts, return_mask=True, max_seq_len=self.model.config.max_seq_len)
if state['auto_max_new_tokens']: if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - ids[0].shape[-1] max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
else: else:
@ -181,7 +181,7 @@ class ExllamaModel:
return output return output
def encode(self, string, **kwargs): def encode(self, string, **kwargs):
return self.tokenizer.encode(string) return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len)
def decode(self, string, **kwargs): def decode(self, string, **kwargs):
return self.tokenizer.decode(string)[0] return self.tokenizer.decode(string)[0]