From a777c058aff905e142700b912cd8e3ddbd3fe8e1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 26 Apr 2023 03:21:53 -0300 Subject: [PATCH] Precise prompts for instruct mode --- characters/instruction-following/Alpaca.yaml | 3 +- characters/instruction-following/ChatGLM.yaml | 1 + characters/instruction-following/Koala.yaml | 3 +- characters/instruction-following/LLaVA.yaml | 3 +- .../instruction-following/Open Assistant.yaml | 2 +- .../instruction-following/RWKV-Raven.yaml | 3 + .../instruction-following/Vicuna-v0.yaml | 4 ++ characters/instruction-following/Vicuna.yaml | 7 +- models/config.yaml | 3 + modules/chat.py | 64 +++++++++++++------ modules/shared.py | 2 +- modules/ui.py | 2 +- server.py | 6 +- settings-template.json | 2 +- 14 files changed, 71 insertions(+), 34 deletions(-) create mode 100644 characters/instruction-following/RWKV-Raven.yaml create mode 100644 characters/instruction-following/Vicuna-v0.yaml diff --git a/characters/instruction-following/Alpaca.yaml b/characters/instruction-following/Alpaca.yaml index 30373242..4379d02b 100644 --- a/characters/instruction-following/Alpaca.yaml +++ b/characters/instruction-following/Alpaca.yaml @@ -1,3 +1,4 @@ name: "### Response:" your_name: "### Instruction:" -context: "Below is an instruction that describes a task. Write a response that appropriately completes the request." +context: "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" +turn_template: "<|user|>\n<|user-message|>\n\n<|bot|>\n<|bot-message|>\n\n" diff --git a/characters/instruction-following/ChatGLM.yaml b/characters/instruction-following/ChatGLM.yaml index 02a26855..aed51eb9 100644 --- a/characters/instruction-following/ChatGLM.yaml +++ b/characters/instruction-following/ChatGLM.yaml @@ -1,3 +1,4 @@ name: "答:" your_name: "[Round <|round|>]\n问:" context: "" +turn_template: "<|user|><|user-message|>\n<|bot|><|bot-message|>\n" diff --git a/characters/instruction-following/Koala.yaml b/characters/instruction-following/Koala.yaml index 18dc7b9b..c0f64098 100644 --- a/characters/instruction-following/Koala.yaml +++ b/characters/instruction-following/Koala.yaml @@ -1,3 +1,4 @@ name: "GPT:" your_name: "USER:" -context: "BEGINNING OF CONVERSATION:" +context: "BEGINNING OF CONVERSATION: " +turn_template: "<|user|> <|user-message|> <|bot|><|bot-message|>" diff --git a/characters/instruction-following/LLaVA.yaml b/characters/instruction-following/LLaVA.yaml index b3999e46..7a8cf1d3 100644 --- a/characters/instruction-following/LLaVA.yaml +++ b/characters/instruction-following/LLaVA.yaml @@ -1,3 +1,4 @@ name: "### Assistant" your_name: "### Human" -context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n" \ No newline at end of file +context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n" +turn_template: "<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n" \ No newline at end of file diff --git a/characters/instruction-following/Open Assistant.yaml b/characters/instruction-following/Open Assistant.yaml index 5b3320ff..bf199517 100644 --- a/characters/instruction-following/Open Assistant.yaml +++ b/characters/instruction-following/Open Assistant.yaml @@ -1,3 +1,3 @@ name: "<|assistant|>" your_name: "<|prompter|>" -end_of_turn: "<|endoftext|>" +turn_template: "<|user|><|user-message|><|endoftext|><|bot|><|bot-message|><|endoftext|>" diff --git a/characters/instruction-following/RWKV-Raven.yaml b/characters/instruction-following/RWKV-Raven.yaml new file mode 100644 index 00000000..867b8705 --- /dev/null +++ b/characters/instruction-following/RWKV-Raven.yaml @@ -0,0 +1,3 @@ +name: "Alice:" +your_name: "Bob:" +turn_template: "<|user|> <|user-message|>\n\n<|bot|><|bot-message|>\n\n" diff --git a/characters/instruction-following/Vicuna-v0.yaml b/characters/instruction-following/Vicuna-v0.yaml new file mode 100644 index 00000000..43b2a28c --- /dev/null +++ b/characters/instruction-following/Vicuna-v0.yaml @@ -0,0 +1,4 @@ +name: "### Assistant:" +your_name: "### Human:" +context: "A chat between a human and an assistant.\n\n" +turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n" diff --git a/characters/instruction-following/Vicuna.yaml b/characters/instruction-following/Vicuna.yaml index 026901d4..9b00b764 100644 --- a/characters/instruction-following/Vicuna.yaml +++ b/characters/instruction-following/Vicuna.yaml @@ -1,3 +1,4 @@ -name: "### Assistant:" -your_name: "### Human:" -context: "Below is an instruction that describes a task. Write a response that appropriately completes the request." +name: "ASSISTANT:" +your_name: "USER:" +context: "A chat between a user and an assistant.\n\n" +turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n" diff --git a/models/config.yaml b/models/config.yaml index e9aa3a55..e43c1648 100644 --- a/models/config.yaml +++ b/models/config.yaml @@ -52,3 +52,6 @@ llama-[0-9]*b-4bit$: mode: 'instruct' model_type: 'llama' instruction_template: 'LLaVA' +.*raven: + mode: 'instruct' + instruction_template: 'RWKV-Raven' diff --git a/modules/chat.py b/modules/chat.py index efdc0de8..299ecf38 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -17,12 +17,20 @@ from modules.text_generation import (encode, generate_reply, get_max_prompt_length) +# Replace multiple string pairs in a string +def replace_all(text, dic): + for i, j in dic.items(): + text = text.replace(i, j) + + return text + + def generate_chat_prompt(user_input, state, **kwargs): impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False _continue = kwargs['_continue'] if '_continue' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False is_instruct = state['mode'] == 'instruct' - rows = [f"{state['context'].strip()}\n"] + rows = [state['context'] if is_instruct else f"{state['context'].strip()}\n"] min_rows = 3 # Finding the maximum prompt size @@ -31,38 +39,50 @@ def generate_chat_prompt(user_input, state, **kwargs): chat_prompt_size -= shared.soft_prompt_tensor.shape[1] max_length = min(get_max_prompt_length(state), chat_prompt_size) - if is_instruct: - prefix1 = f"{state['name1']}\n" - prefix2 = f"{state['name2']}\n" - else: - prefix1 = f"{state['name1']}: " - prefix2 = f"{state['name2']}: " + # Building the turn templates + if 'turn_template' not in state or state['turn_template'] == '': + if is_instruct: + template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n' + else: + template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n' + else: + template = state['turn_template'].replace(r'\n', '\n') + + replacements = { + '<|user|>': state['name1'].strip(), + '<|bot|>': state['name2'].strip(), + } + + user_turn = replace_all(template.split('<|bot|>')[0], replacements) + bot_turn = replace_all('<|bot|>' + template.split('<|bot|>')[1], replacements) + user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements) + bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements) + + # Building the prompt i = len(shared.history['internal']) - 1 while i >= 0 and len(encode(''.join(rows))[0]) < max_length: if _continue and i == len(shared.history['internal']) - 1: - rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}") + rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip()) else: - rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n") + rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip())) string = shared.history['internal'][i][0] if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: - this_prefix1 = prefix1.replace('<|round|>', f'{i}') # for ChatGLM - rows.insert(1, f"{this_prefix1}{string.strip()}{state['end_of_turn']}\n") + rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)})) i -= 1 if impersonate: min_rows = 2 - rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") + rows.append(user_turn_stripped) elif not _continue: # Adding the user message if len(user_input) > 0: - this_prefix1 = prefix1.replace('<|round|>', f'{len(shared.history["internal"])}') # for ChatGLM - rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n") + rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))})) # Adding the Character prefix - rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}")) + rows.append(apply_extensions("bot_prefix", bot_turn_stripped)) while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length: rows.pop(1) @@ -416,7 +436,7 @@ def generate_pfp_cache(character): def load_character(character, name1, name2, mode): shared.character = character - context = greeting = end_of_turn = "" + context = greeting = turn_template = "" greeting_field = 'greeting' picture = None @@ -445,7 +465,9 @@ def load_character(character, name1, name2, mode): data[field] = replace_character_names(data[field], name1, name2) if 'context' in data: - context = f"{data['context'].strip()}\n\n" + context = data['context'] + if mode != 'instruct': + context = context.strip() + '\n\n' elif "char_persona" in data: context = build_pygmalion_style_context(data) greeting_field = 'char_greeting' @@ -456,14 +478,14 @@ def load_character(character, name1, name2, mode): if greeting_field in data: greeting = data[greeting_field] - if 'end_of_turn' in data: - end_of_turn = data['end_of_turn'] + if 'turn_template' in data: + turn_template = data['turn_template'] else: context = shared.settings['context'] name2 = shared.settings['name2'] greeting = shared.settings['greeting'] - end_of_turn = shared.settings['end_of_turn'] + turn_template = shared.settings['turn_template'] if mode != 'instruct': shared.history['internal'] = [] @@ -479,7 +501,7 @@ def load_character(character, name1, name2, mode): # Create .json log files since they don't already exist save_history(mode) - return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode) + return name1, name2, picture, greeting, context, repr(turn_template)[1:-1], chat_html_wrapper(shared.history['visible'], name1, name2, mode) def upload_character(json_file, img, tavern=False): diff --git a/modules/shared.py b/modules/shared.py index 4881ffa5..849b9cef 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -39,7 +39,7 @@ settings = { 'name2': 'Assistant', 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.', 'greeting': '', - 'end_of_turn': '', + 'turn_template': '', 'custom_stopping_strings': '', 'stop_at_newline': False, 'add_bos_token': True, diff --git a/modules/ui.py b/modules/ui.py index 0d62ab3c..c40e596b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -35,7 +35,7 @@ def list_model_elements(): def list_interface_input_elements(chat=False): elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu'] if chat: - elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu'] + elements += ['name1', 'name2', 'greeting', 'context', 'turn_template', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu'] elements += list_model_elements() return elements diff --git a/server.py b/server.py index b7d6aa34..b8566aed 100644 --- a/server.py +++ b/server.py @@ -553,7 +553,7 @@ def create_interface(): shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting') shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context') - shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings['end_of_turn'], lines=1, label='End of turn string') + shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.') with gr.Column(scale=1): shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil') @@ -778,7 +778,7 @@ def create_interface(): chat.redraw_html, reload_inputs, shared.gradio['display']) shared.gradio['instruction_template'].change( - chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then( + chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'turn_template', 'display']]).then( chat.redraw_html, reload_inputs, shared.gradio['display']) shared.gradio['upload_chat_history'].upload( @@ -791,7 +791,7 @@ def create_interface(): shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download']) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) - shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) + shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'turn_template', 'display']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display']) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") diff --git a/settings-template.json b/settings-template.json index 55032aa9..9465d799 100644 --- a/settings-template.json +++ b/settings-template.json @@ -8,7 +8,7 @@ "name2": "Assistant", "context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.", "greeting": "", - "end_of_turn": "", + "turn_template": "", "custom_stopping_strings": "", "stop_at_newline": false, "add_bos_token": true,