From 185587a33ec9882be9bad525e3abd5b87a516998 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 20 Jan 2023 17:03:09 -0300
Subject: [PATCH] Add a history size parameter to the chat

If too many messages are used in the prompt, the model gets really
slow. It is useful to have the ability to limit this.
---
 server.py              | 31 +++++++++++++++++++++----------
 settings-template.json |  3 +++
 2 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/server.py b/server.py
index 3c1a0fa7..c7b3560f 100644
--- a/server.py
+++ b/server.py
@@ -49,6 +49,9 @@ settings = {
     'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
     'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
     'stop_at_newline': True,
+    'history_size': 8,
+    'history_size_min': 0,
+    'history_size_max': 64,
     'preset_pygmalion': 'Pygmalion',
     'name1_pygmalion': 'You',
     'name2_pygmalion': 'Kawaii',
@@ -229,16 +232,21 @@ if args.chat or args.cai_chat:
         text = text.strip()
         return text
 
-    def generate_chat_prompt(text, tokens, name1, name2, context):
+    def generate_chat_prompt(text, tokens, name1, name2, context, history_size):
         text = clean_chat_message(text)
 
         rows = [f"{context.strip()}\n"]
         i = len(history)-1
+        count = 0
         while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
             rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
+            count += 1
             if not (i == 0 and len(history[i][0]) == 0):
                 rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
+                count += 1
             i -= 1
+            if history_size != 0 and count >= history_size:
+                break
 
         rows.append(f"{name1}: {text}\n")
         rows.append(f"{name2}:")
@@ -247,10 +255,11 @@ if args.chat or args.cai_chat:
             rows.pop(1)
 
         question = ''.join(rows)
+        print(question)
         return question
 
-    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        question = generate_chat_prompt(text, tokens, name1, name2, context)
+    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+        question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
         history.append(['', ''])
         eos_token = '\n' if check else None
         for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
@@ -288,8 +297,8 @@ if args.chat or args.cai_chat:
 
         yield history
 
-    def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+    def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
             yield generate_chat_html(history, name1, name2, character)
 
     def remove_last_message(name1, name2):
@@ -362,11 +371,12 @@ if args.chat or args.cai_chat:
                     stop = gr.Button("Stop")
                     btn3 = gr.Button("Remove last message")
 
-            length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
             with gr.Row():
                 with gr.Column():
+                    length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
                     model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                 with gr.Column():
+                    history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size'])
                     preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
 
             name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
@@ -385,13 +395,14 @@ if args.chat or args.cai_chat:
                 save_btn = gr.Button(value="Click me")
                 download = gr.File()
 
+        input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
         if args.cai_chat:
-            gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen")
-            gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream)
+            gen_event = btn.click(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen")
+            gen_event2 = textbox.submit(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream)
             btn2.click(clear_html, [], display1, show_progress=False)
         else:
-            gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen")
-            gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream)
+            gen_event = btn.click(chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen")
+            gen_event2 = textbox.submit(chatbot_wrapper, input_params, display1, show_progress=args.no_stream)
             btn2.click(lambda x: "", display1, display1, show_progress=False)
             btn2.click(clear)
 
diff --git a/settings-template.json b/settings-template.json
index 25fad95e..1199f104 100644
--- a/settings-template.json
+++ b/settings-template.json
@@ -9,6 +9,9 @@
     "prompt": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
     "prompt_gpt4chan": "-----\n--- 865467536\nInput text\n--- 865467537\n",
     "stop_at_newline": true,
+    "history_size": 8,
+    "history_size_min": 0,
+    "history_size_max": 64,
     "preset_pygmalion": "Pygmalion",
     "name1_pygmalion": "You",
     "name2_pygmalion": "Kawaii",
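Note on the truncation logic: the loop in generate_chat_prompt walks the chat
history backwards, inserting the newest exchanges first, and stops as soon as
either the 2048-token budget or the new history_size limit is reached. Below
is a minimal standalone sketch of that logic under stated assumptions:
encode() is a whitespace-splitting stand-in for the real tokenizer in
server.py, and build_prompt is a hypothetical name for this extracted version.

    def encode(text):
        # Stand-in for the real tokenizer: the length of this list
        # approximates the prompt's token count.
        return text.split()

    def build_prompt(text, history, name1, name2, context,
                     history_size=8, max_tokens=2048, new_tokens=200):
        rows = [f"{context.strip()}\n"]
        i = len(history) - 1
        count = 0
        # Walk the history backwards, newest exchange first, until either
        # the token budget or the message-count limit is hit.
        while i >= 0 and len(encode(''.join(rows))) < max_tokens - new_tokens:
            rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
            count += 1
            if not (i == 0 and len(history[i][0]) == 0):
                rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
                count += 1
            i -= 1
            # history_size == 0 means "no limit"; otherwise stop once
            # count messages have been included.
            if history_size != 0 and count >= history_size:
                break
        rows.append(f"{name1}: {text}\n")
        rows.append(f"{name2}:")
        return ''.join(rows)

    history = [["Hi!", "Hello! How can I help?"],
               ["What is 2+2?", "4."]]
    print(build_prompt("And 3+3?", history, "You", "Kawaii",
                       "A friendly chat.", history_size=2))

Note that history_size counts individual messages rather than question/answer
pairs, so the default of 8 keeps roughly the last four exchanges. Because the
check runs only after a full exchange has been inserted, an odd limit can be
exceeded by one message; a value of 0 disables the limit entirely.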
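Separately, the patch replaces the repeated Gradio input lists with a single
input_params list shared by the click and submit events. A minimal sketch of
that wiring pattern, assuming gradio is installed; respond() is an
illustrative stand-in for the real chatbot_wrapper:

    import gradio as gr

    def respond(message, history_size):
        # Illustrative handler; the real app streams model output instead.
        return f"(history_size={int(history_size)}) you said: {message}"

    with gr.Blocks() as demo:
        textbox = gr.Textbox(label='Input')
        history_size_slider = gr.Slider(minimum=0, maximum=64, step=1, value=8,
                                        label='Chat history size (0 for no limit)')
        output = gr.Textbox(label='Output')
        btn = gr.Button('Generate')
        # Collect the shared inputs once so both event handlers stay in sync.
        input_params = [textbox, history_size_slider]
        btn.click(respond, input_params, output)
        textbox.submit(respond, input_params, output)

    demo.launch()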