From cbd63eeeff246a13f9a943ba7067bd1bd2b1012e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 17 Jun 2023 19:02:08 -0300
Subject: [PATCH] Fix repeated tokens with exllama

---
 modules/exllama.py | 23 ++++++++---------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/modules/exllama.py b/modules/exllama.py
index 8b1222f9..6355b60f 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -51,21 +51,7 @@ class ExllamaModel:
         self.generator = generator
         return result, result
 
-    def generate(self, prompt, state, callback=None):
-        self.generator.settings.temperature = state['temperature']
-        self.generator.settings.top_p = state['top_p']
-        self.generator.settings.top_k = state['top_k']
-        self.generator.settings.typical = state['typical_p']
-        self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
-        if state['ban_eos_token']:
-            self.generator.disallow_tokens([self.tokenizer.eos_token_id])
-        else:
-            self.generator.disallow_tokens(None)
-
-        text = self.generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
-        return text
-
-    def generate_with_streaming(self, prompt, state, callback=None):
+    def generate_with_streaming(self, prompt, state):
         self.generator.settings.temperature = state['temperature']
         self.generator.settings.top_p = state['top_p']
         self.generator.settings.top_k = state['top_k']
@@ -86,5 +72,12 @@ class ExllamaModel:
             if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
                 break
 
+    def generate(self, prompt, state):
+        output = ''
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
+
     def encode(self, string, **kwargs):
         return self.tokenizer.encode(string)