diff --git a/server.py b/server.py index ce38c66f..d65998bb 100644 --- a/server.py +++ b/server.py @@ -604,12 +604,9 @@ def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top last = history['visible'].pop() history['internal'].pop() text = last[0] - if args.cai_chat: - for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): - yield i - else: - for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): - yield i + function_call = "cai_chatbot_wrapper" if args.cai_chat else "chatbot_wrapper" + for i in eval(function_call)(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): + yield i def remove_last_message(name1, name2): if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': @@ -937,27 +934,24 @@ if args.chat or args.cai_chat: input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider] if args.picture: input_params.append(picture_select) - if args.cai_chat: - gen_events.append(buttons["Generate"].click(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream)) - if args.picture: - picture_select.upload(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream) - else: - gen_events.append(buttons["Generate"].click(chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(chatbot_wrapper, input_params, display, show_progress=args.no_stream)) - if args.picture: - picture_select.upload(chatbot_wrapper, input_params, display, show_progress=args.no_stream) + function_call = "cai_chatbot_wrapper" if args.cai_chat else "chatbot_wrapper" + + gen_events.append(buttons["Generate"].click(eval(function_call), input_params, display, show_progress=args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(eval(function_call), input_params, display, show_progress=args.no_stream)) + if args.picture: + picture_select.upload(eval(function_call), input_params, display, show_progress=args.no_stream) gen_events.append(buttons["Regenerate"].click(regenerate_wrapper, input_params, display, show_progress=args.no_stream)) gen_events.append(buttons["Impersonate"].click(impersonate_wrapper, input_params, textbox, show_progress=args.no_stream)) + buttons["Stop"].click(None, None, None, cancels=gen_events) buttons["Send last reply to input"].click(send_last_reply_to_input, [], textbox, show_progress=args.no_stream) buttons["Replace last reply"].click(replace_last_reply, [textbox, name1, name2], display, show_progress=args.no_stream) buttons["Clear"].click(clear_chat_log, [character_menu, name1, name2], display) buttons["Remove last"].click(remove_last_message, [name1, name2], [display, textbox], show_progress=False) - buttons["Stop"].click(None, None, None, cancels=gen_events) buttons["Download"].click(save_history, inputs=[], outputs=[download]) buttons["Upload character"].click(upload_character, [upload_char, upload_img], [character_menu]) + # Clearing stuff and saving the history for i in ["Generate", "Regenerate", "Replace last reply"]: buttons[i].click(lambda x: "", textbox, textbox, show_progress=False) buttons[i].click(lambda : save_history(timestamp=False), [], [], show_progress=False) @@ -970,7 +964,6 @@ if args.chat or args.cai_chat: upload_img_me.upload(upload_your_profile_picture, [upload_img_me], []) if args.picture: picture_select.upload(lambda : None, [], [picture_select], show_progress=False) - if args.cai_chat: upload.upload(redraw_html, [name1, name2], [display]) upload_img_me.upload(redraw_html, [name1, name2], [display])