Mirror of https://github.com/oobabooga/text-generation-webui.git
Fix ExLlama truncation
Commit ef17da70af (parent ee964bcce9)
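The change is the same in all three hunks: each tokenizer.encode() call now receives max_seq_len=self.model.config.max_seq_len, so encoded prompts are capped at the model's context length at encode time instead of being returned at full length.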
@@ -111,7 +111,7 @@ class ExllamaModel:
             self.generator.end_beam_search()
 
             # Tokenizing the input
-            ids = self.generator.tokenizer.encode(prompt)
+            ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len)
             ids = ids[:, -get_max_prompt_length(state):]
             if state['auto_max_new_tokens']:
                 max_new_tokens = state['truncation_length'] - ids.shape[-1]
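After this hunk, truncation happens in two stages: encode() itself clamps to the model's max_seq_len, and the existing slice then trims to the prompt budget. A minimal sketch of that interplay, with a hypothetical truncate_ids helper, a plain list standing in for the real 2-D token tensor, and assumed lengths:

# Hypothetical stand-in for the two-stage truncation above; the real code works on
# a 2-D tensor (ids[:, -n:]) and gets the budget from get_max_prompt_length(state).
def truncate_ids(ids, max_seq_len, max_prompt_length):
    ids = ids[-max_seq_len:]         # what encode(..., max_seq_len=...) guarantees
    return ids[-max_prompt_length:]  # what the ids[:, -get_max_prompt_length(state):] slice does

ids = list(range(5000))  # pretend the prompt encoded to 5000 tokens
print(len(truncate_ids(ids, max_seq_len=4096, max_prompt_length=2048)))  # -> 2048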
@@ -141,7 +141,7 @@ class ExllamaModel:
             alpha = state['guidance_scale']
             prompts = [prompt, state['negative_prompt'] or '']
 
-            ids, mask = self.tokenizer.encode(prompts, return_mask=True)
+            ids, mask = self.tokenizer.encode(prompts, return_mask=True, max_seq_len=self.model.config.max_seq_len)
             if state['auto_max_new_tokens']:
                 max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
             else:
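This is the classifier-free guidance path: the prompt and the negative prompt are encoded as a batch, and return_mask=True marks which positions hold real tokens. A hedged illustration of that shape contract, with a hypothetical encode_batch helper, an assumed pad ID, and left padding that may not match ExLlama's actual internals:

import torch

# Hypothetical illustration of what a batched encode with return_mask=True yields:
# padded ID rows plus a mask marking real tokens. Pad ID and padding side are
# assumptions, not ExLlama internals.
def encode_batch(token_rows, pad_id=0):
    width = max(len(r) for r in token_rows)
    ids = torch.full((len(token_rows), width), pad_id, dtype=torch.long)
    mask = torch.zeros((len(token_rows), width), dtype=torch.bool)
    for i, row in enumerate(token_rows):
        ids[i, width - len(row):] = torch.tensor(row)   # left-pad each row
        mask[i, width - len(row):] = True
    return ids, mask

ids, mask = encode_batch([[5, 6, 7, 8], [9, 10]])
print(ids.shape, mask.sum(dim=1))  # torch.Size([2, 4]) tensor([4, 2])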
@@ -181,7 +181,7 @@ class ExllamaModel:
         return output
 
     def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string)
+        return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len)
 
     def decode(self, string, **kwargs):
         return self.tokenizer.decode(string)[0]
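The module-level encode()/decode() wrappers get the same cap. A minimal sketch of the wrapper pattern in this hunk, using a fake character-level tokenizer; everything here is a stand-in, not the real ExLlama classes:

# Stand-ins only: a character-level "tokenizer" and a wrapper mirroring the
# encode()/decode() methods in the hunk above.
class FakeTokenizer:
    def encode(self, s, max_seq_len=None):
        ids = [ord(c) for c in s]                           # crude stand-in for tokenization
        return [ids[-max_seq_len:] if max_seq_len else ids]  # batch of one row

    def decode(self, rows):
        return [''.join(chr(i) for i in row) for row in rows]

class Wrapper:
    def __init__(self, tokenizer, max_seq_len):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def encode(self, string, **kwargs):
        # Cap at the model's context length, as the patched method does.
        return self.tokenizer.encode(string, max_seq_len=self.max_seq_len)

    def decode(self, ids, **kwargs):
        # decode() returns a batch; the wrapper unwraps the single row.
        return self.tokenizer.decode(ids)[0]

w = Wrapper(FakeTokenizer(), max_seq_len=8)
print(w.decode(w.encode("hello world")))  # -> 'lo world' (last 8 "tokens")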