mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 00:18:20 +01:00

Commit 1c30e1b49a: Add even more sliders
parent 24dc705eca

server.py (68 lines changed)
@@ -169,6 +169,11 @@ def load_preset_values(preset_menu, return_dict=False):
         'typical_p': 1,
         'repetition_penalty': 1,
         'top_k': 50,
+        'num_beams': 1,
+        'min_length': 0,
+        'length_penalty': 1,
+        'no_repeat_ngram_size': 0,
+        'early_stopping': False,
     }
     with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
         preset = infile.read()
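The five new keys above extend the defaults that a preset file can override. As a rough illustration only (this is a hypothetical helper, not the parser that follows in load_preset_values, and the key=value file format is assumed), merging a preset onto these defaults could look like:

import ast
from pathlib import Path

# Hypothetical sketch: merge "key=value" entries from presets/<name>.txt
# onto the default generation parameters added in this commit.
DEFAULTS = {
    'do_sample': True, 'temperature': 1, 'top_p': 1, 'typical_p': 1,
    'repetition_penalty': 1, 'top_k': 50, 'num_beams': 1, 'min_length': 0,
    'length_penalty': 1, 'no_repeat_ngram_size': 0, 'early_stopping': False,
}

def load_preset(name: str) -> dict:
    settings = dict(DEFAULTS)
    text = Path(f'presets/{name}.txt').read_text()
    for entry in text.replace(',', '\n').splitlines():
        if '=' in entry:
            key, value = entry.split('=', 1)
            settings[key.strip()] = ast.literal_eval(value.strip())
    return settings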
@@ -182,7 +187,7 @@ def load_preset_values(preset_menu, return_dict=False):
     if return_dict:
         return settings
     else:
-        return settings['do_sample'], settings['temperature'], settings['top_p'], settings['typical_p'], settings['repetition_penalty'], settings['top_k']
+        return settings['do_sample'], settings['temperature'], settings['top_p'], settings['typical_p'], settings['repetition_penalty'], settings['top_k'], settings['min_length'], settings['no_repeat_ngram_size'], settings['num_beams'], settings['length_penalty'], settings['early_stopping']

 # Removes empty replies from gpt4chan outputs
 def fix_gpt4chan(s):
@@ -228,7 +233,7 @@ def formatted_outputs(reply, model_name):
     else:
         return reply

-def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, eos_token=None, stopping_string=None):
+def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=None, stopping_string=None):
     global model_name, model, tokenizer

     original_question = question
@@ -262,8 +267,15 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
         f"typical_p={typical_p}",
         f"repetition_penalty={repetition_penalty}",
         f"top_k={top_k}",
+        f"min_length={min_length}",
+        f"no_repeat_ngram_size={no_repeat_ngram_size}",
+        f"num_beams={num_beams}",
+        f"length_penalty={length_penalty}",
+        f"early_stopping={early_stopping}",
     ]

+    print(generate_params)
+
     if args.deepspeed:
         generate_params.append("synced_gpus=True")
     if args.no_stream:
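For reference, the five new entries correspond to standard Hugging Face transformers generation arguments; generate_reply builds them as strings and the surrounding code (not shown in this diff) presumably folds them into the model.generate() call. A minimal standalone sketch of what the parameters mean when passed directly to generate(); the model name and the sampled values are placeholders, not taken from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
output = model.generate(
    input_ids,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    typical_p=1.0,
    repetition_penalty=1.1,
    top_k=50,
    min_length=0,                 # new slider: minimum length of the generated sequence
    no_repeat_ngram_size=0,       # new slider: block repeated n-grams when > 0
    num_beams=1,                  # new slider: beam search width (1 = no beam search)
    length_penalty=1.0,           # new slider: beam-search preference for longer/shorter outputs
    early_stopping=False,         # new checkbox: stop beam search once enough beams finish
)
print(tokenizer.decode(output[0]))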
@@ -373,14 +385,24 @@ def create_settings_menus():
                 do_sample = gr.Checkbox(value=defaults['do_sample'], label="do_sample")
                 temperature = gr.Slider(0.01, 1.99, value=defaults['temperature'], step=0.01, label="temperature")
                 top_p = gr.Slider(0.0,1.0,value=defaults['top_p'],step=0.01,label="top_p")
-            with gr.Column():
                 typical_p = gr.Slider(0.0,1.0,value=defaults['typical_p'],step=0.01,label="typical_p")
+            with gr.Column():
                 repetition_penalty = gr.Slider(1.0,5.0,value=defaults['repetition_penalty'],step=0.01,label="repetition_penalty")
                 top_k = gr.Slider(0,200,value=defaults['top_k'],step=1,label="top_k")
+                no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=defaults["no_repeat_ngram_size"], label="no_repeat_ngram_size")
+
+        gr.Markdown("Special parameters (only use them if you really need them):")
+        with gr.Row():
+            with gr.Column():
+                num_beams = gr.Slider(0, 20, step=1, value=defaults["num_beams"], label="num_beams")
+                length_penalty = gr.Slider(0, 5, value=defaults["length_penalty"], label="length_penalty")
+            with gr.Column():
+                min_length = gr.Slider(0, 2000, step=1, value=defaults["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
+                early_stopping = gr.Checkbox(value=defaults["early_stopping"], label="early_stopping")

     model_menu.change(load_model_wrapper, [model_menu], [])
-    preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k])
-    return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k
+    preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping])
+    return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping

 # This gets the new line characters right.
 def clean_chat_message(text):
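The wiring pattern for the new components is the same one already used for the existing sliders: preset_menu.change() pushes all of the values into their components at once. A stripped-down, self-contained sketch of that pattern (the preset data here is made up for illustration and is not from the repo):

import gradio as gr

# Hypothetical presets, illustration only.
PRESETS = {
    "Default": dict(num_beams=1, min_length=0, length_penalty=1.0,
                    no_repeat_ngram_size=0, early_stopping=False),
    "Beams":   dict(num_beams=5, min_length=10, length_penalty=1.2,
                    no_repeat_ngram_size=3, early_stopping=True),
}

def apply_preset(name):
    p = PRESETS[name]
    # Return one value per output component, in the same order as the outputs list below.
    return p["num_beams"], p["min_length"], p["length_penalty"], p["no_repeat_ngram_size"], p["early_stopping"]

with gr.Blocks() as demo:
    preset = gr.Dropdown(choices=list(PRESETS), value="Default", label="preset")
    num_beams = gr.Slider(0, 20, step=1, value=1, label="num_beams")
    min_length = gr.Slider(0, 2000, step=1, value=0, label="min_length")
    length_penalty = gr.Slider(0, 5, value=1.0, label="length_penalty")
    no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=0, label="no_repeat_ngram_size")
    early_stopping = gr.Checkbox(value=False, label="early_stopping")
    preset.change(apply_preset, [preset], [num_beams, min_length, length_penalty, no_repeat_ngram_size, early_stopping])

demo.launch()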
@@ -451,14 +473,14 @@ def extract_message_from_reply(question, reply, current, other, check, extension

     return reply, next_character_found, substring_found

-def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, name1, name2, context, check, history_size):
+def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
     original_text = text
     text = apply_extensions(text, "input")
     question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
     history['internal'].append(['', ''])
     history['visible'].append(['', ''])
     eos_token = '\n' if check else None
-    for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, eos_token=eos_token, stopping_string=f"\n{name1}:"):
+    for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
         reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
         history['internal'][-1] = [text, reply]
         history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
@@ -468,10 +490,10 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p,
             break
         yield history['visible']

-def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, name1, name2, context, check, history_size):
+def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
     question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
     eos_token = '\n' if check else None
-    for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, eos_token=eos_token, stopping_string=f"\n{name2}:"):
+    for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
         reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
         if not substring_found:
             yield apply_extensions(reply, "output")
@@ -479,19 +501,19 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
             break
         yield apply_extensions(reply, "output")

-def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, name1, name2, context, check, history_size):
-    for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, name1, name2, context, check, history_size):
+def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
+    for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
         yield generate_chat_html(_history, name1, name2, character)

-def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, name1, name2, context, check, history_size):
+def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
     last = history['visible'].pop()
     history['internal'].pop()
     text = last[0]
     if args.cai_chat:
-        for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, name1, name2, context, check, history_size):
+        for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
             yield i
     else:
-        for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, name1, name2, context, check, history_size):
+        for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
             yield i

 def remove_last_message(name1, name2):
@@ -749,7 +771,7 @@ if args.chat or args.cai_chat:
             with gr.Column():
                 history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size in prompt (0 for no limit)', value=settings['history_size'])

-        preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k = create_settings_menus()
+        preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus()

         name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
         name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
@@ -787,7 +809,7 @@ if args.chat or args.cai_chat:
     if args.extensions is not None:
         create_extensions_block()

-    input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, name1, name2, context, check, history_size_slider]
+    input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size_slider]
     if args.cai_chat:
         gen_events.append(buttons["Generate"].click(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen"))
         gen_events.append(textbox.submit(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream))
@@ -834,13 +856,13 @@ elif args.notebook:

     max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])

-    preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k = create_settings_menus()
+    preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus()

     if args.extensions is not None:
         create_extensions_block()

-    gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
-    gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k], [textbox, markdown, html], show_progress=args.no_stream))
+    gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
+    gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream))
     buttons["Stop"].click(None, None, None, cancels=gen_events)

 else:
@@ -857,7 +879,7 @@ else:
         with gr.Column():
             buttons["Stop"] = gr.Button("Stop")

-    preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k = create_settings_menus()
+    preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus()
     if args.extensions is not None:
         create_extensions_block()

@@ -869,9 +891,9 @@ else:
         with gr.Tab('HTML'):
             html = gr.HTML()

-    gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
-    gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k], [output_textbox, markdown, html], show_progress=args.no_stream))
-    gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k], [output_textbox, markdown, html], show_progress=args.no_stream))
+    gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
+    gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream))
+    gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream))
     buttons["Stop"].click(None, None, None, cancels=gen_events)

 interface.queue()