From 1c74b3ab451ec72b2a2a40ec493fcb6b7b1afba5 Mon Sep 17 00:00:00 2001
From: Yiximail
Date: Fri, 8 Dec 2023 20:50:53 +0800
Subject: [PATCH] Fix partial unicode characters issue (#4837)

---
 modules/exllama.py         | 19 ++++++++++++++++++-
 modules/exllamav2.py       | 10 +++++++++-
 modules/text_generation.py |  7 ++++++-
 3 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/modules/exllama.py b/modules/exllama.py
index 4257ee07..25c4c99d 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -165,10 +165,19 @@ class ExllamaModel:
                 if has_leading_space:
                     decoded_text = ' ' + decoded_text
 
-                yield decoded_text
+                # Check the partial unicode character
+                if chr(0xfffd) in decoded_text:
+                    is_last = i == max_new_tokens - 1
+                    is_stopping = token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything
+                    # If we are not at the end of the generation, we skip this token
+                    if not (is_last or is_stopping):
+                        continue
+
                 if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
                     break
 
+                yield decoded_text
+
         # Case 2: CFG
         # Copied from https://github.com/turboderp/exllama/blob/master/example_cfg.py
         else:
@@ -205,6 +214,14 @@
                 if has_leading_space:
                     decoded_text = ' ' + decoded_text
 
+                # Check the partial unicode character
+                if chr(0xfffd) in decoded_text:
+                    is_last = i == max_new_tokens - 1
+                    is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
+                    # If we are not at the end of the generation, we skip this token
+                    if not (is_last or is_stopping):
+                        continue
+
                 yield decoded_text
                 if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
                     break
diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index b92e8840..d755a36a 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -138,11 +138,19 @@ class Exllamav2Model:
             if has_leading_space:
                 decoded_text = ' ' + decoded_text
 
-            yield decoded_text
+            # Check the partial unicode character
+            if chr(0xfffd) in decoded_text:
+                is_last = i == max_new_tokens - 1
+                is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
+                # If we are not at the end of the generation, we skip this token
+                if not (is_last or is_stopping):
+                    continue
 
             if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
                 break
 
+            yield decoded_text
+
     def generate(self, prompt, state):
         output = ''
         for output in self.generate_with_streaming(prompt, state):
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 4cf4f720..417ac194 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -362,7 +362,12 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
                     if output[-1] in eos_token_ids:
                         break
 
-                    cumulative_reply += get_reply_from_output_ids(output, state, starting_from=starting_from)
+                    new_content = get_reply_from_output_ids(output, state, starting_from=starting_from)
+                    # check the partial unicode character
+                    if chr(0xfffd) in new_content:
+                        continue
+
+                    cumulative_reply += new_content
                     starting_from = len(output)
 
                     yield cumulative_reply