From 766c760cd778d4b96aedcfe70b4ee21420c0d6d3 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 17 Jun 2023 18:00:10 -0300
Subject: [PATCH] Use gen_begin_reuse in exllama

---
 modules/exllama.py | 50 +++++++++++++++++++++++++---------------------
 1 file changed, 27 insertions(+), 23 deletions(-)

diff --git a/modules/exllama.py b/modules/exllama.py
index b8fcc4af..8b1222f9 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -41,45 +41,49 @@ class ExllamaModel:
         model = ExLlama(config)
         tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
         cache = ExLlamaCache(model)
+        generator = ExLlamaGenerator(model, tokenizer, cache)

         result = self()
         result.config = config
         result.model = model
         result.cache = cache
         result.tokenizer = tokenizer
+        self.generator = generator
         return result, result

     def generate(self, prompt, state, callback=None):
-        generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
-        generator.settings.temperature = state['temperature']
-        generator.settings.top_p = state['top_p']
-        generator.settings.top_k = state['top_k']
-        generator.settings.typical = state['typical_p']
-        generator.settings.token_repetition_penalty_max = state['repetition_penalty']
+        self.generator.settings.temperature = state['temperature']
+        self.generator.settings.top_p = state['top_p']
+        self.generator.settings.top_k = state['top_k']
+        self.generator.settings.typical = state['typical_p']
+        self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
         if state['ban_eos_token']:
-            generator.disallow_tokens([self.tokenizer.eos_token_id])
+            self.generator.disallow_tokens([self.tokenizer.eos_token_id])
+        else:
+            self.generator.disallow_tokens(None)

-        text = generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
+        text = self.generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
         return text

     def generate_with_streaming(self, prompt, state, callback=None):
-        generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
-        generator.settings.temperature = state['temperature']
-        generator.settings.top_p = state['top_p']
-        generator.settings.top_k = state['top_k']
-        generator.settings.typical = state['typical_p']
-        generator.settings.token_repetition_penalty_max = state['repetition_penalty']
+        self.generator.settings.temperature = state['temperature']
+        self.generator.settings.top_p = state['top_p']
+        self.generator.settings.top_k = state['top_k']
+        self.generator.settings.typical = state['typical_p']
+        self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
         if state['ban_eos_token']:
-            generator.disallow_tokens([self.tokenizer.eos_token_id])
+            self.generator.disallow_tokens([self.tokenizer.eos_token_id])
+        else:
+            self.generator.disallow_tokens(None)

-        generator.end_beam_search()
-        ids = generator.tokenizer.encode(prompt)
-        generator.gen_begin(ids)
-        initial_len = generator.sequence[0].shape[0]
-        for i in range(state['max_new_tokens']):
-            token = generator.gen_single_token()
-            yield (generator.tokenizer.decode(generator.sequence[0][initial_len:]))
-            if token.item() == generator.tokenizer.eos_token_id or shared.stop_everything:
+        self.generator.end_beam_search()
+        ids = self.generator.tokenizer.encode(prompt)
+        self.generator.gen_begin_reuse(ids)
+        initial_len = self.generator.sequence[0].shape[0]
+        for _ in range(state['max_new_tokens']):
+            token = self.generator.gen_single_token()
+            yield (self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:]))
+            if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
                 break

     def encode(self, string, **kwargs):
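Note on the change: `gen_begin` re-evaluates the entire prompt from scratch on every request, while `gen_begin_reuse` keeps the key/value cache populated by the previous request and only re-evaluates the part of the new prompt that differs from what is already cached. In a chat loop, where each new prompt is the previous prompt plus one more exchange, this skips almost all of the prompt-processing work. The sketch below is a simplified model of that prefix-reuse idea using plain Python lists, not ExLlama's actual implementation; the helper names are hypothetical.

    def shared_prefix_len(a, b):
        # Length of the longest common prefix of two token-id lists.
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    def begin_reuse(cached_ids, new_ids):
        # Sketch of the reuse step: keep the cached prefix and return only
        # the suffix of new_ids that still needs a forward pass. The real
        # generator would truncate the KV cache to `keep` positions here.
        keep = shared_prefix_len(cached_ids, new_ids)
        return new_ids[:keep], new_ids[keep:]

    prev_prompt = [1, 15, 29, 3, 88]          # token ids already in the cache
    next_prompt = [1, 15, 29, 3, 88, 42, 7]   # same prompt plus a new turn
    print(begin_reuse(prev_prompt, next_prompt))  # ([1, 15, 29, 3, 88], [42, 7])

The new `else: self.generator.disallow_tokens(None)` branches follow from the same change: the generator is now a long-lived shared object rather than one built per request, so an EOS ban applied for one request has to be cleared explicitly before the next.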