mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-11 21:10:40 +01:00
Add no-stream option
This commit is contained in:
parent
116299b3ad
commit
99536ef5bf
39
server.py
39
server.py
@ -25,6 +25,7 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
|
||||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||
parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
|
||||
parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.')
|
||||
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
|
||||
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -139,6 +140,7 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
preset = infile.read()
|
||||
loaded_preset = inference_settings
|
||||
|
||||
if not args.no_stream:
|
||||
input_ids = encode(question, 1)
|
||||
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
|
||||
cuda = "" if args.cpu else ".cuda()"
|
||||
@ -158,6 +160,25 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
|
||||
|
||||
input_ids = output
|
||||
else:
|
||||
input_ids = encode(question, tokens)
|
||||
cuda = "" if args.cpu else ".cuda()"
|
||||
if eos_token is None:
|
||||
output = eval(f"model.generate(input_ids, {preset}){cuda}")
|
||||
else:
|
||||
n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
|
||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
|
||||
reply = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
reply = reply.replace(r'<|endoftext|>', '')
|
||||
if model_name.lower().startswith('galactica'):
|
||||
reply = fix_galactica(reply)
|
||||
yield reply, reply, generate_basic_html(reply)
|
||||
elif model_name.lower().startswith('gpt4chan'):
|
||||
reply = fix_gpt4chan(reply)
|
||||
yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
|
||||
else:
|
||||
yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
|
||||
|
||||
|
||||
# Choosing the default model
|
||||
if args.model is not None:
|
||||
@ -305,12 +326,12 @@ if args.chat or args.cai_chat:
|
||||
check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')
|
||||
|
||||
if args.cai_chat:
|
||||
gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
|
||||
gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
|
||||
gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen")
|
||||
gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream)
|
||||
btn2.click(clear_html, [], display1, show_progress=False)
|
||||
else:
|
||||
gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
|
||||
gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
|
||||
gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream, api_name="textgen")
|
||||
gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=args.no_stream)
|
||||
btn2.click(lambda x: "", display1, display1, show_progress=False)
|
||||
|
||||
btn2.click(clear)
|
||||
@ -338,8 +359,8 @@ elif args.notebook:
|
||||
with gr.Column():
|
||||
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
|
||||
|
||||
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen")
|
||||
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
|
||||
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
|
||||
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
|
||||
stop.click(None, None, None, cancels=[gen_event, gen_event2])
|
||||
|
||||
else:
|
||||
@ -365,9 +386,9 @@ else:
|
||||
with gr.Tab('HTML'):
|
||||
html = gr.HTML()
|
||||
|
||||
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen")
|
||||
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)
|
||||
cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)
|
||||
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
|
||||
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)
|
||||
cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)
|
||||
stop.click(None, None, None, cancels=[gen_event, gen_event2, cont_event])
|
||||
|
||||
interface.queue()
|
||||
|
Loading…
x
Reference in New Issue
Block a user