diff --git a/modules/text_generation.py b/modules/text_generation.py index 7ee1225b..7bc5d25d 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -107,8 +107,10 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i else: new_tokens = len(output_ids) - len(input_ids[0]) reply = decode(output_ids[-new_tokens:], state['skip_special_tokens']) - if type(shared.tokenizer) is transformers.LlamaTokenizer: - if len(original_question) > 0 and original_question[-1] not in [' ', '\n']: + + # Prevent LlamaTokenizer from skipping a space + if type(shared.tokenizer) is transformers.LlamaTokenizer and len(output_ids) > 0: + if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'): reply = ' ' + reply if not is_chat: