From 1d8b8222e95e7669dd431a60866ebdc786ea76dd Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 26 Apr 2023 16:47:50 -0300 Subject: [PATCH] Revert #1579, apply the proper fix Apparently models dislike trailing spaces. --- characters/instruction-following/Vicuna-v0.yaml | 2 +- characters/instruction-following/Vicuna.yaml | 2 +- modules/chat.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/characters/instruction-following/Vicuna-v0.yaml b/characters/instruction-following/Vicuna-v0.yaml index 5312ca77..43b2a28c 100644 --- a/characters/instruction-following/Vicuna-v0.yaml +++ b/characters/instruction-following/Vicuna-v0.yaml @@ -1,4 +1,4 @@ name: "### Assistant:" your_name: "### Human:" context: "A chat between a human and an assistant.\n\n" -turn_template: "<|user|> <|user-message|>\n<|bot|><|bot-message|>\n" +turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n" diff --git a/characters/instruction-following/Vicuna.yaml b/characters/instruction-following/Vicuna.yaml index 627dcb83..9b00b764 100644 --- a/characters/instruction-following/Vicuna.yaml +++ b/characters/instruction-following/Vicuna.yaml @@ -1,4 +1,4 @@ name: "ASSISTANT:" your_name: "USER:" context: "A chat between a user and an assistant.\n\n" -turn_template: "<|user|> <|user-message|>\n<|bot|><|bot-message|>\n" +turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n" diff --git a/modules/chat.py b/modules/chat.py index 299ecf38..2969fd6a 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -56,8 +56,8 @@ def generate_chat_prompt(user_input, state, **kwargs): user_turn = replace_all(template.split('<|bot|>')[0], replacements) bot_turn = replace_all('<|bot|>' + template.split('<|bot|>')[1], replacements) - user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements) - bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements) + user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements).rstrip(' ') + bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements).rstrip(' ') # Building the prompt i = len(shared.history['internal']) - 1