From 99536ef5bfe1d3a23162b1f0d83d28e40f900442 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 18 Jan 2023 23:56:42 -0300 Subject: [PATCH] Add no-stream option --- server.py | 55 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/server.py b/server.py index c595a634..a18120ee 100644 --- a/server.py +++ b/server.py @@ -25,6 +25,7 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.') parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.') +parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') args = parser.parse_args() @@ -139,15 +140,36 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok preset = infile.read() loaded_preset = inference_settings - input_ids = encode(question, 1) - preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1') - cuda = "" if args.cpu else ".cuda()" - for i in range(tokens): - output = eval(f"model.generate(input_ids, {preset}){cuda}") + if not args.no_stream: + input_ids = encode(question, 1) + preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1') + cuda = "" if args.cpu else ".cuda()" + for i in range(tokens): + output = eval(f"model.generate(input_ids, {preset}){cuda}") + reply = tokenizer.decode(output[0], skip_special_tokens=True) + reply = reply.replace(r'<|endoftext|>', '') + if eos_token is not None and reply[-1] == eos_token: + break + if model_name.lower().startswith('galactica'): + reply = fix_galactica(reply) + yield reply, reply, generate_basic_html(reply) + elif model_name.lower().startswith('gpt4chan'): + reply = fix_gpt4chan(reply) + yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) + else: + yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply) + + input_ids = output + else: + input_ids = encode(question, tokens) + cuda = "" if args.cpu else ".cuda()" + if eos_token is None: + output = eval(f"model.generate(input_ids, {preset}){cuda}") + else: + n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1] + output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}") reply = tokenizer.decode(output[0], skip_special_tokens=True) reply = reply.replace(r'<|endoftext|>', '') - if eos_token is not None and reply[-1] == eos_token: - break if model_name.lower().startswith('galactica'): reply = fix_galactica(reply) yield reply, reply, generate_basic_html(reply) @@ -157,7 +179,6 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok else: yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply) - input_ids = output # Choosing the default model if args.model is not None: @@ -305,12 +326,12 @@ if args.chat or args.cai_chat: check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?') if args.cai_chat: - gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen") - gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False) + gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen") + gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream) btn2.click(clear_html, [], display1, show_progress=False) else: - gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen") - gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False) + gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen") + gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream) btn2.click(lambda x: "", display1, display1, show_progress=False) btn2.click(clear) @@ -338,8 +359,8 @@ elif args.notebook: with gr.Column(): preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') - gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen") - gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False) + gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen") + gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream) stop.click(None, None, None, cancels=[gen_event, gen_event2]) else: @@ -365,9 +386,9 @@ else: with gr.Tab('HTML'): html = gr.HTML() - gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen") - gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False) - cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False) + gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen") + gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream) + cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream) stop.click(None, None, None, cancels=[gen_event, gen_event2, cont_event]) interface.queue()