mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Add greed parameter
This commit is contained in:
parent
13f2688134
commit
2dfb999bf1
@ -84,7 +84,7 @@ def extract_message_from_reply(question, reply, current, other, check, extension
|
|||||||
def stop_everything_event():
|
def stop_everything_event():
|
||||||
shared.stop_everything = True
|
shared.stop_everything = True
|
||||||
|
|
||||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
|
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed=1):
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
just_started = True
|
just_started = True
|
||||||
eos_token = '\n' if check else None
|
eos_token = '\n' if check else None
|
||||||
@ -112,30 +112,33 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
|||||||
prompt = custom_prompt_generator(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
prompt = custom_prompt_generator(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
for reply in generate_reply(prompt, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
reply = ''
|
||||||
|
for i in range(greed):
|
||||||
|
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||||
|
|
||||||
# Extracting the reply
|
# Extracting the reply
|
||||||
reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name2, name1, check, extensions=True)
|
reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name2, name1, check, extensions=True)
|
||||||
visible_reply = apply_extensions(reply, "output")
|
visible_reply = apply_extensions(reply, "output")
|
||||||
if shared.args.chat:
|
if shared.args.chat:
|
||||||
visible_reply = visible_reply.replace('\n', '<br>')
|
visible_reply = visible_reply.replace('\n', '<br>')
|
||||||
|
|
||||||
# We need this global variable to handle the Stop event,
|
# We need this global variable to handle the Stop event,
|
||||||
# otherwise gradio gets confused
|
# otherwise gradio gets confused
|
||||||
if shared.stop_everything:
|
if shared.stop_everything:
|
||||||
return shared.history['visible']
|
return shared.history['visible']
|
||||||
if just_started:
|
if just_started:
|
||||||
just_started = False
|
just_started = False
|
||||||
shared.history['internal'].append(['', ''])
|
shared.history['internal'].append(['', ''])
|
||||||
shared.history['visible'].append(['', ''])
|
shared.history['visible'].append(['', ''])
|
||||||
|
|
||||||
shared.history['internal'][-1] = [text, reply]
|
shared.history['internal'][-1] = [text, reply]
|
||||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||||
if not substring_found:
|
if not substring_found:
|
||||||
yield shared.history['visible']
|
yield shared.history['visible']
|
||||||
if next_character_found:
|
if next_character_found:
|
||||||
break
|
break
|
||||||
yield shared.history['visible']
|
yield shared.history['visible']
|
||||||
|
print(i, reply)
|
||||||
|
|
||||||
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
|
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
|
||||||
eos_token = '\n' if check else None
|
eos_token = '\n' if check else None
|
||||||
@ -153,11 +156,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
|
|||||||
break
|
break
|
||||||
yield reply
|
yield reply
|
||||||
|
|
||||||
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
|
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed=1):
|
||||||
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
|
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed):
|
||||||
yield generate_chat_html(_history, name1, name2, shared.character)
|
yield generate_chat_html(_history, name1, name2, shared.character)
|
||||||
|
|
||||||
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
|
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed=1):
|
||||||
if shared.character != 'None' and len(shared.history['visible']) == 1:
|
if shared.character != 'None' and len(shared.history['visible']) == 1:
|
||||||
if shared.args.cai_chat:
|
if shared.args.cai_chat:
|
||||||
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
||||||
@ -167,7 +170,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
last_visible = shared.history['visible'].pop()
|
last_visible = shared.history['visible'].pop()
|
||||||
last_internal = shared.history['internal'].pop()
|
last_internal = shared.history['internal'].pop()
|
||||||
|
|
||||||
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
|
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed):
|
||||||
if shared.args.cai_chat:
|
if shared.args.cai_chat:
|
||||||
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
|
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
|
||||||
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
||||||
|
@ -241,9 +241,10 @@ if shared.args.chat or shared.args.cai_chat:
|
|||||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||||
|
shared.gradio['greed'] = gr.Slider(minimum=1, maximum=5, value=1, step=1)
|
||||||
create_settings_menus()
|
create_settings_menus()
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider']]
|
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'greed']]
|
||||||
if shared.args.extensions is not None:
|
if shared.args.extensions is not None:
|
||||||
with gr.Tab('Extensions'):
|
with gr.Tab('Extensions'):
|
||||||
extensions_module.create_extensions_block()
|
extensions_module.create_extensions_block()
|
||||||
|
Loading…
Reference in New Issue
Block a user