diff --git a/server.py b/server.py
index 7bec72c4..c838cc0a 100644
--- a/server.py
+++ b/server.py
@@ -139,25 +139,28 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         preset = infile.read()
     loaded_preset = inference_settings
 
-    input_ids = encode(question, tokens)
+    for i in range(tokens):
+        input_ids = encode(question, 1)
+        preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
 
-    cuda = ".cuda()" if args.cpu else ""
-    if eos_token is None:
-        output = eval(f"model.generate(input_ids, {preset}){cuda}")
-    else:
-        n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
-        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
+        cuda = ".cuda()" if args.cpu else ""
+        if eos_token is None:
+            output = eval(f"model.generate(input_ids, {preset}){cuda}")
+        else:
+            n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
+            output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
 
-    reply = tokenizer.decode(output[0], skip_special_tokens=True)
-    reply = reply.replace(r'<|endoftext|>', '')
-    if model_name.lower().startswith('galactica'):
-        reply = fix_galactica(reply)
-        return reply, reply, generate_basic_html(reply)
-    elif model_name.lower().startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-        return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
-    else:
-        return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+        reply = tokenizer.decode(output[0], skip_special_tokens=True)
+        reply = reply.replace(r'<|endoftext|>', '')
+        question = reply
+        if model_name.lower().startswith('galactica'):
+            reply = fix_galactica(reply)
+            yield reply, reply, generate_basic_html(reply)
+        elif model_name.lower().startswith('gpt4chan'):
+            reply = fix_gpt4chan(reply)
+            yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+        else:
+            yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
 
 # Choosing the default model
 if args.model is not None:
@@ -205,20 +208,20 @@ if args.notebook:
             with gr.Column():
                 preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
 
-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True, api_name="textgen")
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True)
+        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen")
+        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
 elif args.chat or args.cai_chat:
     history = []
 
     # This gets the new line characters right.
-    def chat_response_cleaner(text):
+    def clean_chat_message(text):
         text = text.replace('\n', '\n\n')
         text = re.sub(r"\n{3,}", "\n\n", text)
         text = text.strip()
         return text
 
-    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        text = chat_response_cleaner(text)
+    def generate_chat_prompt(text, tokens, name1, name2, context):
+        text = clean_chat_message(text)
         rows = [f"{context}\n\n"]
         i = len(history)-1
@@ -234,26 +237,42 @@ elif args.chat or args.cai_chat:
             rows.pop(1)
 
         question = ''.join(rows)
+        return question
 
-        if check:
-            reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
-        else:
-            reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):]
-            idx = reply.find(f"\n{name1}:")
-            if idx != -1:
-                reply = reply[:idx]
-            reply = chat_response_cleaner(reply)
+    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+        history.append(['', ''])
+        question = generate_chat_prompt(text, tokens, name1, name2, context)
+        eos_token = '\n' if check else None
+        for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
+            reply = i[0]
 
-        history.append((text, reply))
-        return history
+            if check:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
+            else:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):]
+                idx = reply.find(f"\n{name1}:")
+                if idx != -1:
+                    reply = reply[:idx]
+                reply = clean_chat_message(reply)
+
+            history[-1] = [text, reply]
+
+            # Prevent the chat log from flashing if something like "\nYo" is generated just
+            # before "\nYou:" is completed
+            tmp = f"\n{name1}:"
+            found = False
+            for j in range(1, len(tmp)):
+                if reply[-j:] == tmp[:j]:
+                    found = True
+
+            if not found:
+                yield history
 
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        history = chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check)
-        return generate_chat_html(history, name1, name2)
+        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+            yield generate_chat_html(history, name1, name2)
 
     def remove_last_message(name1, name2):
         history.pop()
@@ -305,13 +324,13 @@ elif args.chat or args.cai_chat:
             check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')
 
     if args.cai_chat:
-        btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-        textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-        btn2.click(clear_html, [], display1, show_progress=False)
+        btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+        textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+        btn2.click(clear_html, [], display1, show_progress=True)
     else:
-        btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-        textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-        btn2.click(lambda x: "", display1, display1)
+        btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+        textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+        btn2.click(lambda x: "", display1, display1, show_progress=True)
     btn2.click(clear)
     btn3.click(remove_last_message, [name1, name2], display1, show_progress=False)
@@ -320,8 +339,9 @@ elif args.chat or args.cai_chat:
 else:
 
     def continue_wrapper(question, tokens, inference_settings, selected_model):
-        a, b, c = generate_reply(question, tokens, inference_settings, selected_model)
-        return a, a, b, c
+        for i in generate_reply(question, tokens, inference_settings, selected_model):
+            a, b, c = i
+            yield a, a, b, c
 
     with gr.Blocks(css=css, analytics_enabled=False) as interface:
         gr.Markdown(description)
@@ -341,10 +361,11 @@ else:
            with gr.Tab('HTML'):
                html = gr.HTML()
 
-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True, api_name="textgen")
-        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=True)
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
+        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen")
+        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=False)
+        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)
 
+interface.queue()
 if args.no_listen:
     interface.launch(share=False)
 else:
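The recurring change in this patch is one pattern applied in several places: `generate_reply` becomes a generator that produces one token per pass and yields the partial reply, every wrapper that consumed it becomes a generator as well, the event handlers switch to `show_progress=False`, and `interface.queue()` is added so the yielded updates can be streamed to the UI. A minimal, self-contained sketch of that pattern, assuming simplified signatures and a hypothetical `fake_model_step` in place of the real `model.generate` call:

```python
# Sketch of the streaming pattern introduced by the patch, independent of the
# real model: `fake_model_step` is a hypothetical stand-in for the one-token
# model.generate call; only the generator/consumer structure mirrors the diff.

def fake_model_step(prompt):
    # Pretend to extend the prompt by one token.
    return prompt + " token"

def generate_reply(question, tokens):
    # Instead of returning once at the end, yield the partial reply after
    # every generated token so the UI callback can refresh incrementally.
    reply = question
    for _ in range(tokens):
        reply = fake_model_step(reply)
        question = reply  # feed the partial reply back in, as the patch does
        yield reply

def continue_wrapper(question, tokens):
    # Wrappers that post-process the reply become generators too:
    # they re-yield each partial result instead of returning the last one.
    for reply in generate_reply(question, tokens):
        yield reply

if __name__ == "__main__":
    for partial in continue_wrapper("Hello", 3):
        print(partial)
```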