From b9e0712b92ab81eee50740253798d90ed835a43a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 12 Mar 2023 23:58:25 -0300
Subject: [PATCH] Fix Open Assistant

---
 modules/text_generation.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/modules/text_generation.py b/modules/text_generation.py
index 6ee9d931..f5d2b8d0 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -37,9 +37,13 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
     return input_ids.cuda()
 
 def decode(output_ids):
-    reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
-    reply = reply.replace(r'<|endoftext|>', '')
-    return reply
+    # Open Assistant relies on special tokens like <|endoftext|>
+    if re.match('oasst-*', shared.model_name.lower()):
+        return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
+    else:
+        reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
+        reply = reply.replace(r'<|endoftext|>', '')
+        return reply
 
 def generate_softprompt_input_tensors(input_ids):
     inputs_embeds = shared.model.transformer.wte(input_ids)