Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-11-25 17:29:22 +01:00

Stop generating at \n in chat mode
Makes chat mode a lot more efficient: generation now stops at the first newline token instead of running all the way to max_length and discarding everything after the first line.
This commit is contained in:
parent a9280dde52
commit f2a548c098

server.py · 14 changed lines
@@ -69,7 +69,7 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
-def generate_reply(question, temperature, max_length, inference_settings, selected_model):
+def generate_reply(question, temperature, max_length, inference_settings, selected_model, eos_token=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
     if selected_model != model_name:
@@ -86,7 +86,11 @@ def generate_reply(question, temperature, max_length, inference_settings, select
         torch.cuda.empty_cache()
     input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda()
 
-    output = eval(f"model.generate(input_ids, {preset}).cuda()")
+    if eos_token is None:
+        output = eval(f"model.generate(input_ids, {preset}).cuda()")
+    else:
+        n = tokenizer.encode(eos_token, return_tensors='pt')[0][1]
+        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}).cuda()")
     reply = tokenizer.decode(output[0], skip_special_tokens=True)
 
     if model_name.lower().startswith('galactica'):
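
The new branch resolves eos_token to a concrete token id and passes it to generate(), which treats eos_token_id as a stop condition. A minimal standalone sketch of the same mechanism, assuming a Hugging Face causal LM (the model name and prompt below are illustrative placeholders, not from the commit):

    # Sketch only: stop generation at the first newline by passing its token
    # id as eos_token_id. Model name and prompt are illustrative assumptions.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-1.3b')
    model = AutoModelForCausalLM.from_pretrained('facebook/opt-1.3b')

    prompt = 'Person 1: Hello!\nPerson 2:'
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # [0][1] skips the special token that OPT/Galactica-style tokenizers
    # prepend at position 0; a tokenizer without one would need [0][0].
    newline_id = tokenizer.encode('\n', return_tensors='pt')[0][1]

    # generate() halts as soon as eos_token_id is sampled, so the call stops
    # at the first newline instead of running all the way to max_length.
    output = model.generate(input_ids, max_length=200, do_sample=True,
                            eos_token_id=newline_id)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
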
@@ -159,7 +163,7 @@ elif args.chat:
         question += f"{name1}: {text.strip()}\n"
         question += f"{name2}:"
 
-        reply = generate_reply(question, temperature, max_length, inference_settings, selected_model)[0]
+        reply = generate_reply(question, temperature, max_length, inference_settings, selected_model, eos_token='\n')[0]
         reply = reply[len(question):].split('\n')[0].strip()
         history.append((text, reply))
         return history
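
Even with eos_token='\n', the decoded string still begins with the prompt (and, when no stop token is passed, may contain extra lines), which is what the slicing above removes. A small illustration of that trimming step, with a made-up decoded string standing in for the output of tokenizer.decode():

    # Illustration of the trimming line above; `decoded` is a hypothetical
    # example, not real model output.
    question = 'Person 1: Hello!\nPerson 2:'
    decoded = question + ' Hi there!\nPerson 1: How are'
    reply = decoded[len(question):].split('\n')[0].strip()
    assert reply == 'Hi there!'
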
@@ -175,7 +179,7 @@ elif args.chat:
         with gr.Column():
             with gr.Row(equal_height=True):
                 with gr.Column():
-                    length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=100)
+                    length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
                     preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
                 with gr.Column():
                     temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
@@ -203,7 +207,7 @@ else:
         with gr.Column():
             textbox = gr.Textbox(value=default_text, lines=15, label='Input')
             temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
-            length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=100)
+            length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
             preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
             model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
             btn = gr.Button("Generate")
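
Both UI hunks only change the max_length slider default from 100 to 200. For context, a minimal self-contained sketch of the gr.Blocks/Column layout pattern these hunks edit (the echo handler and component set are simplified stand-ins, not the commit's generate_reply wiring):

    # Sketch of the Blocks/Column/Slider pattern used in server.py; the
    # handler and components are simplified assumptions.
    import gradio as gr

    def echo(text, max_length):
        return text[:int(max_length)]  # stand-in for generate_reply()

    with gr.Blocks() as demo:
        with gr.Column():
            textbox = gr.Textbox(lines=15, label='Input')
            length_slider = gr.Slider(minimum=1, maximum=2000, step=1,
                                      label='max_length', value=200)
            btn = gr.Button('Generate')
            output = gr.Textbox(label='Output')
        btn.click(echo, inputs=[textbox, length_slider], outputs=output)

    demo.launch()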