mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-24 13:28:59 +01:00
Add penalty_alpha parameter (contrastive search)
This commit is contained in:
parent
8aafb55693
commit
0dd1409f24
@ -186,4 +186,5 @@ For these two, please try commenting on an existing issue instead of creating a
|
||||
- NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
|
||||
- Pygmalion preset, code for early stopping in chat mode, code for some of the sliders: https://github.com/PygmalionAI/gradio-ui/
|
||||
- Verbose preset: Anonymous 4chan user.
|
||||
- Instruct-Joi preset: https://huggingface.co/Rallio67/joi\_12B\_instruct\_alpha
|
||||
- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
||||
|
5
presets/Instruct-Joi.txt
Normal file
5
presets/Instruct-Joi.txt
Normal file
@ -0,0 +1,5 @@
|
||||
top_p=0.95,
|
||||
temperature=0.5,
|
||||
penalty_alpha=0.6,
|
||||
top_k=4,
|
||||
repetition_penalty=1.03,
|
47
server.py
47
server.py
@ -174,6 +174,7 @@ def load_preset_values(preset_menu, return_dict=False):
|
||||
'repetition_penalty': 1,
|
||||
'top_k': 50,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'min_length': 0,
|
||||
'length_penalty': 1,
|
||||
'no_repeat_ngram_size': 0,
|
||||
@ -191,7 +192,7 @@ def load_preset_values(preset_menu, return_dict=False):
|
||||
if return_dict:
|
||||
return generate_params
|
||||
else:
|
||||
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
def fix_gpt4chan(s):
|
||||
@ -237,7 +238,7 @@ def formatted_outputs(reply, model_name):
|
||||
else:
|
||||
return reply
|
||||
|
||||
def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
||||
def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
||||
global model_name, model, tokenizer
|
||||
|
||||
original_question = question
|
||||
@ -274,6 +275,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
|
||||
f"min_length={min_length if args.no_stream else 0}",
|
||||
f"no_repeat_ngram_size={no_repeat_ngram_size}",
|
||||
f"num_beams={num_beams}",
|
||||
f"penalty_alpha={penalty_alpha}",
|
||||
f"length_penalty={length_penalty}",
|
||||
f"early_stopping={early_stopping}",
|
||||
]
|
||||
@ -392,6 +394,7 @@ def create_settings_menus():
|
||||
repetition_penalty = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label="repetition_penalty")
|
||||
top_k = gr.Slider(0,200,value=generate_params['top_k'],step=1,label="top_k")
|
||||
no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size")
|
||||
penalty_alpha = gr.Slider(0, 5, value=generate_params["penalty_alpha"], label="penalty_alpha")
|
||||
|
||||
gr.Markdown("Special parameters (only use them if you really need them):")
|
||||
with gr.Row():
|
||||
@ -403,8 +406,8 @@ def create_settings_menus():
|
||||
early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
|
||||
|
||||
model_menu.change(load_model_wrapper, [model_menu], [])
|
||||
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping])
|
||||
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping
|
||||
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
|
||||
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
|
||||
|
||||
# This gets the new line characters right.
|
||||
def clean_chat_message(text):
|
||||
@ -475,14 +478,14 @@ def extract_message_from_reply(question, reply, current, other, check, extension
|
||||
|
||||
return reply, next_character_found, substring_found
|
||||
|
||||
def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
original_text = text
|
||||
text = apply_extensions(text, "input")
|
||||
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
|
||||
history['internal'].append(['', ''])
|
||||
history['visible'].append(['', ''])
|
||||
eos_token = '\n' if check else None
|
||||
for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||
for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||
reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
|
||||
history['internal'][-1] = [text, reply]
|
||||
history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
|
||||
@ -492,10 +495,10 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p,
|
||||
break
|
||||
yield history['visible']
|
||||
|
||||
def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
|
||||
eos_token = '\n' if check else None
|
||||
for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
||||
for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
||||
reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
|
||||
if not substring_found:
|
||||
yield apply_extensions(reply, "output")
|
||||
@ -503,19 +506,19 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
|
||||
break
|
||||
yield apply_extensions(reply, "output")
|
||||
|
||||
def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
yield generate_chat_html(_history, name1, name2, character)
|
||||
|
||||
def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
last = history['visible'].pop()
|
||||
history['internal'].pop()
|
||||
text = last[0]
|
||||
if args.cai_chat:
|
||||
for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
yield i
|
||||
else:
|
||||
for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
|
||||
yield i
|
||||
|
||||
def remove_last_message(name1, name2):
|
||||
@ -775,7 +778,7 @@ if args.chat or args.cai_chat:
|
||||
with gr.Column():
|
||||
history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size in prompt (0 for no limit)', value=settings['history_size'])
|
||||
|
||||
preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus()
|
||||
preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
|
||||
|
||||
name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
|
||||
name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
|
||||
@ -813,7 +816,7 @@ if args.chat or args.cai_chat:
|
||||
if args.extensions is not None:
|
||||
create_extensions_block()
|
||||
|
||||
input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size_slider]
|
||||
input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size_slider]
|
||||
if args.cai_chat:
|
||||
gen_events.append(buttons["Generate"].click(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen"))
|
||||
gen_events.append(textbox.submit(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream))
|
||||
@ -860,13 +863,13 @@ elif args.notebook:
|
||||
|
||||
max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
|
||||
|
||||
preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus()
|
||||
preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
|
||||
|
||||
if args.extensions is not None:
|
||||
create_extensions_block()
|
||||
|
||||
gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
|
||||
gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream))
|
||||
gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
|
||||
gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream))
|
||||
buttons["Stop"].click(None, None, None, cancels=gen_events)
|
||||
|
||||
else:
|
||||
@ -883,7 +886,7 @@ else:
|
||||
with gr.Column():
|
||||
buttons["Stop"] = gr.Button("Stop")
|
||||
|
||||
preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus()
|
||||
preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
|
||||
if args.extensions is not None:
|
||||
create_extensions_block()
|
||||
|
||||
@ -895,9 +898,9 @@ else:
|
||||
with gr.Tab('HTML'):
|
||||
html = gr.HTML()
|
||||
|
||||
gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
|
||||
gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream))
|
||||
gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream))
|
||||
gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
|
||||
gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream))
|
||||
gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream))
|
||||
buttons["Stop"].click(None, None, None, cancels=gen_events)
|
||||
|
||||
interface.queue()
|
||||
|
Loading…
Reference in New Issue
Block a user