Add no-stream option

This commit is contained in:
oobabooga 2023-01-18 23:56:42 -03:00
parent 116299b3ad
commit 99536ef5bf

View File

@ -25,6 +25,7 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.') parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.') parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
args = parser.parse_args() args = parser.parse_args()
@ -139,15 +140,36 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
preset = infile.read() preset = infile.read()
loaded_preset = inference_settings loaded_preset = inference_settings
input_ids = encode(question, 1) if not args.no_stream:
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1') input_ids = encode(question, 1)
cuda = "" if args.cpu else ".cuda()" preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
for i in range(tokens): cuda = "" if args.cpu else ".cuda()"
output = eval(f"model.generate(input_ids, {preset}){cuda}") for i in range(tokens):
output = eval(f"model.generate(input_ids, {preset}){cuda}")
reply = tokenizer.decode(output[0], skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '')
if eos_token is not None and reply[-1] == eos_token:
break
if model_name.lower().startswith('galactica'):
reply = fix_galactica(reply)
yield reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else:
yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
input_ids = output
else:
input_ids = encode(question, tokens)
cuda = "" if args.cpu else ".cuda()"
if eos_token is None:
output = eval(f"model.generate(input_ids, {preset}){cuda}")
else:
n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
reply = tokenizer.decode(output[0], skip_special_tokens=True) reply = tokenizer.decode(output[0], skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '') reply = reply.replace(r'<|endoftext|>', '')
if eos_token is not None and reply[-1] == eos_token:
break
if model_name.lower().startswith('galactica'): if model_name.lower().startswith('galactica'):
reply = fix_galactica(reply) reply = fix_galactica(reply)
yield reply, reply, generate_basic_html(reply) yield reply, reply, generate_basic_html(reply)
@ -157,7 +179,6 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
else: else:
yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply) yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
input_ids = output
# Choosing the default model # Choosing the default model
if args.model is not None: if args.model is not None:
@ -305,12 +326,12 @@ if args.chat or args.cai_chat:
check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?') check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')
if args.cai_chat: if args.cai_chat:
gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen") gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen")
gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False) gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream)
btn2.click(clear_html, [], display1, show_progress=False) btn2.click(clear_html, [], display1, show_progress=False)
else: else:
gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen") gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen")
gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False) gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream)
btn2.click(lambda x: "", display1, display1, show_progress=False) btn2.click(lambda x: "", display1, display1, show_progress=False)
btn2.click(clear) btn2.click(clear)
@ -338,8 +359,8 @@ elif args.notebook:
with gr.Column(): with gr.Column():
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen") gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False) gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
stop.click(None, None, None, cancels=[gen_event, gen_event2]) stop.click(None, None, None, cancels=[gen_event, gen_event2])
else: else:
@ -365,9 +386,9 @@ else:
with gr.Tab('HTML'): with gr.Tab('HTML'):
html = gr.HTML() html = gr.HTML()
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen") gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False) gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)
cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False) cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)
stop.click(None, None, None, cancels=[gen_event, gen_event2, cont_event]) stop.click(None, None, None, cancels=[gen_event, gen_event2, cont_event])
interface.queue() interface.queue()