Set chat prompt size in tokens

This commit is contained in:
oobabooga 2023-02-15 10:18:50 -03:00
parent 1622059179
commit 7be372829d
2 changed files with 19 additions and 21 deletions

View File

@ -71,9 +71,9 @@ settings = {
'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', 'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n', 'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
'stop_at_newline': True, 'stop_at_newline': True,
'history_size': 0, 'chat_prompt_size': 2048,
'history_size_min': 0, 'chat_prompt_size_min': 0,
'history_size_max': 64, 'chat_prompt_size_max': 2048,
'preset_pygmalion': 'Pygmalion', 'preset_pygmalion': 'Pygmalion',
'name1_pygmalion': 'You', 'name1_pygmalion': 'You',
'name2_pygmalion': 'Kawaii', 'name2_pygmalion': 'Kawaii',
@ -503,13 +503,13 @@ def clean_chat_message(text):
text = text.strip() text = text.strip()
return text return text
def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=False): def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
text = clean_chat_message(text) text = clean_chat_message(text)
rows = [f"{context.strip()}\n"] rows = [f"{context.strip()}\n"]
i = len(history['internal'])-1 i = len(history['internal'])-1
count = 0 count = 0
max_length = get_max_prompt_length(tokens) max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length: while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n") rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
count += 1 count += 1
@ -517,8 +517,6 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n") rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
count += 1 count += 1
i -= 1 i -= 1
if history_size != 0 and count >= history_size:
break
if not impersonate: if not impersonate:
rows.append(f"{name1}: {text}\n") rows.append(f"{name1}: {text}\n")
@ -566,14 +564,14 @@ def extract_message_from_reply(question, reply, current, other, check, extension
return reply, next_character_found, substring_found return reply, next_character_found, substring_found
def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None): def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
if args.picture and picture is not None: if args.picture and picture is not None:
text, visible_text = generate_chat_picture(picture, name1, name2) text, visible_text = generate_chat_picture(picture, name1, name2)
else: else:
visible_text = text visible_text = text
text = apply_extensions(text, "input") text = apply_extensions(text, "input")
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size) question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size)
history['internal'].append(['', '']) history['internal'].append(['', ''])
history['visible'].append(['', '']) history['visible'].append(['', ''])
eos_token = '\n' if check else None eos_token = '\n' if check else None
@ -587,8 +585,8 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p,
break break
yield history['visible'] yield history['visible']
def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None): def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True) question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True)
eos_token = '\n' if check else None eos_token = '\n' if check else None
for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False) reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
@ -598,19 +596,19 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
break break
yield apply_extensions(reply, "output") yield apply_extensions(reply, "output")
def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None): def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture): for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
yield generate_chat_html(_history, name1, name2, character) yield generate_chat_html(_history, name1, name2, character)
def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None): def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
last = history['visible'].pop() last = history['visible'].pop()
history['internal'].pop() history['internal'].pop()
text = last[0] text = last[0]
if args.cai_chat: if args.cai_chat:
for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture): for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
yield i yield i
else: else:
for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture): for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
yield i yield i
def remove_last_message(name1, name2): def remove_last_message(name1, name2):
@ -886,7 +884,7 @@ if args.chat or args.cai_chat:
with gr.Column(): with gr.Column():
max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
with gr.Column(): with gr.Column():
history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size in prompt (0 for no limit)', value=settings['history_size']) chat_prompt_size_slider = gr.Slider(minimum=settings['chat_prompt_size_min'], maximum=settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=settings['chat_prompt_size'])
preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
@ -926,7 +924,7 @@ if args.chat or args.cai_chat:
if args.extensions is not None: if args.extensions is not None:
create_extensions_block() create_extensions_block()
input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size_slider] input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider]
if args.picture: if args.picture:
input_params.append(picture_select) input_params.append(picture_select)
if args.cai_chat: if args.cai_chat:

View File

@ -9,9 +9,9 @@
"prompt": "Common sense questions and answers\n\nQuestion: \nFactual answer:", "prompt": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"prompt_gpt4chan": "-----\n--- 865467536\nInput text\n--- 865467537\n", "prompt_gpt4chan": "-----\n--- 865467536\nInput text\n--- 865467537\n",
"stop_at_newline": true, "stop_at_newline": true,
"history_size": 0, "chat_prompt_size": 2048,
"history_size_min": 0, "chat_prompt_size_min": 0,
"history_size_max": 64, "chat_prompt_size_max": 2048,
"preset_pygmalion": "Pygmalion", "preset_pygmalion": "Pygmalion",
"name1_pygmalion": "You", "name1_pygmalion": "You",
"name2_pygmalion": "Kawaii", "name2_pygmalion": "Kawaii",