Precise prompts for instruct mode

This commit is contained in:
oobabooga 2023-04-26 03:21:53 -03:00 committed by GitHub
parent a8409426d7
commit a777c058af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 71 additions and 34 deletions

View File

@ -1,3 +1,4 @@
name: "### Response:" name: "### Response:"
your_name: "### Instruction:" your_name: "### Instruction:"
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request." context: "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
turn_template: "<|user|>\n<|user-message|>\n\n<|bot|>\n<|bot-message|>\n\n"

View File

@ -1,3 +1,4 @@
name: "答:" name: "答:"
your_name: "[Round <|round|>]\n问:" your_name: "[Round <|round|>]\n问:"
context: "" context: ""
turn_template: "<|user|><|user-message|>\n<|bot|><|bot-message|>\n"

View File

@ -1,3 +1,4 @@
name: "GPT:" name: "GPT:"
your_name: "USER:" your_name: "USER:"
context: "BEGINNING OF CONVERSATION:" context: "BEGINNING OF CONVERSATION: "
turn_template: "<|user|> <|user-message|> <|bot|><|bot-message|></s>"

View File

@ -1,3 +1,4 @@
name: "### Assistant" name: "### Assistant"
your_name: "### Human" your_name: "### Human"
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n" context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
turn_template: "<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n"

View File

@ -1,3 +1,3 @@
name: "<|assistant|>" name: "<|assistant|>"
your_name: "<|prompter|>" your_name: "<|prompter|>"
end_of_turn: "<|endoftext|>" turn_template: "<|user|><|user-message|><|endoftext|><|bot|><|bot-message|><|endoftext|>"

View File

@ -0,0 +1,3 @@
name: "Alice:"
your_name: "Bob:"
turn_template: "<|user|> <|user-message|>\n\n<|bot|><|bot-message|>\n\n"

View File

@ -0,0 +1,4 @@
name: "### Assistant:"
your_name: "### Human:"
context: "A chat between a human and an assistant.\n\n"
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n"

View File

@ -1,3 +1,4 @@
name: "### Assistant:" name: "ASSISTANT:"
your_name: "### Human:" your_name: "USER:"
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request." context: "A chat between a user and an assistant.\n\n"
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"

View File

@ -52,3 +52,6 @@ llama-[0-9]*b-4bit$:
mode: 'instruct' mode: 'instruct'
model_type: 'llama' model_type: 'llama'
instruction_template: 'LLaVA' instruction_template: 'LLaVA'
.*raven:
mode: 'instruct'
instruction_template: 'RWKV-Raven'

View File

@ -17,12 +17,20 @@ from modules.text_generation import (encode, generate_reply,
get_max_prompt_length) get_max_prompt_length)
# Replace multiple string pairs in a string
def replace_all(text, dic):
for i, j in dic.items():
text = text.replace(i, j)
return text
def generate_chat_prompt(user_input, state, **kwargs): def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False _continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
is_instruct = state['mode'] == 'instruct' is_instruct = state['mode'] == 'instruct'
rows = [f"{state['context'].strip()}\n"] rows = [state['context'] if is_instruct else f"{state['context'].strip()}\n"]
min_rows = 3 min_rows = 3
# Finding the maximum prompt size # Finding the maximum prompt size
@ -31,38 +39,50 @@ def generate_chat_prompt(user_input, state, **kwargs):
chat_prompt_size -= shared.soft_prompt_tensor.shape[1] chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size) max_length = min(get_max_prompt_length(state), chat_prompt_size)
if is_instruct:
prefix1 = f"{state['name1']}\n"
prefix2 = f"{state['name2']}\n"
else:
prefix1 = f"{state['name1']}: "
prefix2 = f"{state['name2']}: "
# Building the turn templates
if 'turn_template' not in state or state['turn_template'] == '':
if is_instruct:
template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n'
else:
template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n'
else:
template = state['turn_template'].replace(r'\n', '\n')
replacements = {
'<|user|>': state['name1'].strip(),
'<|bot|>': state['name2'].strip(),
}
user_turn = replace_all(template.split('<|bot|>')[0], replacements)
bot_turn = replace_all('<|bot|>' + template.split('<|bot|>')[1], replacements)
user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements)
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
# Building the prompt
i = len(shared.history['internal']) - 1 i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length: while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1: if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}") rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
else: else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n") rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
string = shared.history['internal'][i][0] string = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
this_prefix1 = prefix1.replace('<|round|>', f'{i}') # for ChatGLM rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
rows.insert(1, f"{this_prefix1}{string.strip()}{state['end_of_turn']}\n")
i -= 1 i -= 1
if impersonate: if impersonate:
min_rows = 2 min_rows = 2
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") rows.append(user_turn_stripped)
elif not _continue: elif not _continue:
# Adding the user message # Adding the user message
if len(user_input) > 0: if len(user_input) > 0:
this_prefix1 = prefix1.replace('<|round|>', f'{len(shared.history["internal"])}') # for ChatGLM rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n")
# Adding the Character prefix # Adding the Character prefix
rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}")) rows.append(apply_extensions("bot_prefix", bot_turn_stripped))
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length: while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
rows.pop(1) rows.pop(1)
@ -416,7 +436,7 @@ def generate_pfp_cache(character):
def load_character(character, name1, name2, mode): def load_character(character, name1, name2, mode):
shared.character = character shared.character = character
context = greeting = end_of_turn = "" context = greeting = turn_template = ""
greeting_field = 'greeting' greeting_field = 'greeting'
picture = None picture = None
@ -445,7 +465,9 @@ def load_character(character, name1, name2, mode):
data[field] = replace_character_names(data[field], name1, name2) data[field] = replace_character_names(data[field], name1, name2)
if 'context' in data: if 'context' in data:
context = f"{data['context'].strip()}\n\n" context = data['context']
if mode != 'instruct':
context = context.strip() + '\n\n'
elif "char_persona" in data: elif "char_persona" in data:
context = build_pygmalion_style_context(data) context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting' greeting_field = 'char_greeting'
@ -456,14 +478,14 @@ def load_character(character, name1, name2, mode):
if greeting_field in data: if greeting_field in data:
greeting = data[greeting_field] greeting = data[greeting_field]
if 'end_of_turn' in data: if 'turn_template' in data:
end_of_turn = data['end_of_turn'] turn_template = data['turn_template']
else: else:
context = shared.settings['context'] context = shared.settings['context']
name2 = shared.settings['name2'] name2 = shared.settings['name2']
greeting = shared.settings['greeting'] greeting = shared.settings['greeting']
end_of_turn = shared.settings['end_of_turn'] turn_template = shared.settings['turn_template']
if mode != 'instruct': if mode != 'instruct':
shared.history['internal'] = [] shared.history['internal'] = []
@ -479,7 +501,7 @@ def load_character(character, name1, name2, mode):
# Create .json log files since they don't already exist # Create .json log files since they don't already exist
save_history(mode) save_history(mode)
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode) return name1, name2, picture, greeting, context, repr(turn_template)[1:-1], chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def upload_character(json_file, img, tavern=False): def upload_character(json_file, img, tavern=False):

View File

@ -39,7 +39,7 @@ settings = {
'name2': 'Assistant', 'name2': 'Assistant',
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.', 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
'greeting': '', 'greeting': '',
'end_of_turn': '', 'turn_template': '',
'custom_stopping_strings': '', 'custom_stopping_strings': '',
'stop_at_newline': False, 'stop_at_newline': False,
'add_bos_token': True, 'add_bos_token': True,

View File

@ -35,7 +35,7 @@ def list_model_elements():
def list_interface_input_elements(chat=False): def list_interface_input_elements(chat=False):
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu'] elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
if chat: if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu'] elements += ['name1', 'name2', 'greeting', 'context', 'turn_template', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
elements += list_model_elements() elements += list_model_elements()
return elements return elements

View File

@ -553,7 +553,7 @@ def create_interface():
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting') shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context') shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings['end_of_turn'], lines=1, label='End of turn string') shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.')
with gr.Column(scale=1): with gr.Column(scale=1):
shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil') shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil')
@ -778,7 +778,7 @@ def create_interface():
chat.redraw_html, reload_inputs, shared.gradio['display']) chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['instruction_template'].change( shared.gradio['instruction_template'].change(
chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then( chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'turn_template', 'display']]).then(
chat.redraw_html, reload_inputs, shared.gradio['display']) chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['upload_chat_history'].upload( shared.gradio['upload_chat_history'].upload(
@ -791,7 +791,7 @@ def create_interface():
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download']) shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'turn_template', 'display']])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display']) shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")

View File

@ -8,7 +8,7 @@
"name2": "Assistant", "name2": "Assistant",
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.", "context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
"greeting": "", "greeting": "",
"end_of_turn": "", "turn_template": "",
"custom_stopping_strings": "", "custom_stopping_strings": "",
"stop_at_newline": false, "stop_at_newline": false,
"add_bos_token": true, "add_bos_token": true,