Fix repeated tokens with exllama

This commit is contained in:
oobabooga 2023-06-17 19:02:08 -03:00
parent 766c760cd7
commit cbd63eeeff

View File

@ -51,21 +51,7 @@ class ExllamaModel:
self.generator = generator self.generator = generator
return result, result return result, result
def generate(self, prompt, state, callback=None): def generate_with_streaming(self, prompt, state):
self.generator.settings.temperature = state['temperature']
self.generator.settings.top_p = state['top_p']
self.generator.settings.top_k = state['top_k']
self.generator.settings.typical = state['typical_p']
self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
if state['ban_eos_token']:
self.generator.disallow_tokens([self.tokenizer.eos_token_id])
else:
self.generator.disallow_tokens(None)
text = self.generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
return text
def generate_with_streaming(self, prompt, state, callback=None):
self.generator.settings.temperature = state['temperature'] self.generator.settings.temperature = state['temperature']
self.generator.settings.top_p = state['top_p'] self.generator.settings.top_p = state['top_p']
self.generator.settings.top_k = state['top_k'] self.generator.settings.top_k = state['top_k']
@ -86,5 +72,12 @@ class ExllamaModel:
if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything: if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
break break
def generate(self, prompt, state):
output = ''
for output in self.generate_with_streaming(prompt, state):
pass
return output
def encode(self, string, **kwargs): def encode(self, string, **kwargs):
return self.tokenizer.encode(string) return self.tokenizer.encode(string)