diff --git a/modules/exllama.py b/modules/exllama.py
index 8b1222f9..6355b60f 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -51,21 +51,7 @@ class ExllamaModel:
         self.generator = generator
         return result, result
 
-    def generate(self, prompt, state, callback=None):
-        self.generator.settings.temperature = state['temperature']
-        self.generator.settings.top_p = state['top_p']
-        self.generator.settings.top_k = state['top_k']
-        self.generator.settings.typical = state['typical_p']
-        self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
-        if state['ban_eos_token']:
-            self.generator.disallow_tokens([self.tokenizer.eos_token_id])
-        else:
-            self.generator.disallow_tokens(None)
-
-        text = self.generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
-        return text
-
-    def generate_with_streaming(self, prompt, state, callback=None):
+    def generate_with_streaming(self, prompt, state):
         self.generator.settings.temperature = state['temperature']
         self.generator.settings.top_p = state['top_p']
         self.generator.settings.top_k = state['top_k']
@@ -86,5 +72,12 @@ class ExllamaModel:
             if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
                 break
 
+    def generate(self, prompt, state):
+        output = ''
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
+
     def encode(self, string, **kwargs):
         return self.tokenizer.encode(string)