From b15f51015477c9709e2dff616c20466e9b3dc727 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 20 Dec 2023 07:31:42 -0800
Subject: [PATCH] Optimize ExLlamav2 (non-HF) loader

---
 modules/exllamav2.py | 36 ++++++++----------------------------
 1 file changed, 8 insertions(+), 28 deletions(-)

diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index 2cf4a039..3a6b231a 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -1,4 +1,3 @@
-import random
 import traceback
 from pathlib import Path
 
@@ -10,7 +9,7 @@ from exllamav2 import (
     ExLlamaV2Config,
     ExLlamaV2Tokenizer
 )
-from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
+from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
 
 from modules import shared
 from modules.logging_colors import logger
@@ -64,7 +63,7 @@ class Exllamav2Model:
         else:
             cache = ExLlamaV2Cache(model)
 
-        generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
+        generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
 
         result = self()
         result.model = model
@@ -115,41 +114,22 @@ class Exllamav2Model:
 
         ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
         ids = ids[:, -get_max_prompt_length(state):]
-        initial_len = ids.shape[-1]
 
         if state['auto_max_new_tokens']:
             max_new_tokens = state['truncation_length'] - ids.shape[-1]
         else:
             max_new_tokens = state['max_new_tokens']
 
-        # _gen_begin_base
-        self.cache.current_seq_len = 0
-        self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
+        self.generator.set_stop_conditions([])
+        self.generator.begin_stream(ids, settings, loras=self.loras)
 
-        has_leading_space = False
+        decoded_text = ''
         for i in range(max_new_tokens):
-            logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None, loras=self.loras).float().cpu()
-            token, _, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
-            ids = torch.cat([ids, token], dim=1)
-
-            if i == 0 and self.tokenizer.tokenizer.id_to_piece(int(token)).startswith('▁'):
-                has_leading_space = True
-
-            decoded_text = self.tokenizer.decode(ids[:, initial_len:], decode_special_tokens=not state['skip_special_tokens'])[0]
-            if has_leading_space:
-                decoded_text = ' ' + decoded_text
-
-            # Check the partial unicode character
-            if chr(0xfffd) in decoded_text:
-                is_last = i == max_new_tokens - 1
-                is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
-                # If we are not at the end of the generation, we skip this token
-                if not (is_last or is_stopping):
-                    continue
-
-            if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
+            chunk, eos, _ = self.generator.stream()
+            if eos or shared.stop_everything:
                 break
+
+            decoded_text += chunk
             yield decoded_text
 
     def generate(self, prompt, state):
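
Notes (not part of the patch): the change replaces the manual token-by-token loop in generate_with_streaming(), which called model.forward() and ExLlamaV2Sampler.sample() itself and re-decoded the full output every step to repair the SentencePiece leading space and partial UTF-8 characters, with exllamav2's ExLlamaV2StreamingGenerator, which performs sampling and chunk decoding internally. Below is a minimal, self-contained sketch of the same streaming pattern outside the webui. The loading calls (ExLlamaV2Config.prepare(), ExLlamaV2.load()) follow exllamav2's standard examples from this period and the model path is hypothetical; only set_stop_conditions(), begin_stream() and stream() come directly from the patch.

    from exllamav2 import (
        ExLlamaV2,
        ExLlamaV2Cache,
        ExLlamaV2Config,
        ExLlamaV2Tokenizer
    )
    from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator

    # Load model, tokenizer and KV cache (hypothetical local path).
    config = ExLlamaV2Config()
    config.model_dir = '/path/to/model'
    config.prepare()
    model = ExLlamaV2(config)
    model.load()
    tokenizer = ExLlamaV2Tokenizer(config)
    cache = ExLlamaV2Cache(model)

    # The streaming generator owns sampling and text-chunk buffering,
    # so no manual leading-space or partial-unicode handling is needed.
    generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.7

    ids = tokenizer.encode('Once upon a time', add_bos=True)
    generator.set_stop_conditions([])       # no stop strings; rely on the EOS flag
    generator.begin_stream(ids, settings)

    decoded_text = ''
    for _ in range(200):                    # max_new_tokens
        chunk, eos, _ = generator.stream()  # (text chunk, EOS flag, new token ids)
        if eos:
            break
        decoded_text += chunk

    print(decoded_text)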