diff --git a/modules/text_generation.py b/modules/text_generation.py
index 6ee9d931..f5d2b8d0 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -37,9 +37,13 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
     return input_ids.cuda()
 
 def decode(output_ids):
-    reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
-    reply = reply.replace(r'<|endoftext|>', '')
-    return reply
+    # Open Assistant relies on special tokens like <|endoftext|>
+    if re.match('oasst', shared.model_name.lower()):
+        return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
+    else:
+        reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
+        reply = reply.replace(r'<|endoftext|>', '')
+        return reply
 
 def generate_softprompt_input_tensors(input_ids):
     inputs_embeds = shared.model.transformer.wte(input_ids)