mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Merge branch 'main' into mcmonkey4eva-add-train-lora-tab
This commit is contained in:
commit
c2cad30772
1
.gitignore
vendored
1
.gitignore
vendored
@ -19,3 +19,4 @@ repositories
|
|||||||
settings.json
|
settings.json
|
||||||
img_bot*
|
img_bot*
|
||||||
img_me*
|
img_me*
|
||||||
|
prompts/[0-9]*
|
||||||
|
11
css/main.css
11
css/main.css
@ -37,12 +37,6 @@
|
|||||||
text-decoration: none !important;
|
text-decoration: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
svg {
|
|
||||||
display: unset !important;
|
|
||||||
vertical-align: middle !important;
|
|
||||||
margin: 5px;
|
|
||||||
}
|
|
||||||
|
|
||||||
ol li p, ul li p {
|
ol li p, ul li p {
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
}
|
}
|
||||||
@ -64,3 +58,8 @@ ol li p, ul li p {
|
|||||||
padding: 15px;
|
padding: 15px;
|
||||||
padding: 15px;
|
padding: 15px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
span.math.inline {
|
||||||
|
font-size: 27px;
|
||||||
|
vertical-align: baseline !important;
|
||||||
|
}
|
||||||
|
@ -54,7 +54,7 @@ class Iteratorize:
|
|||||||
self.stop_now = False
|
self.stop_now = False
|
||||||
|
|
||||||
def _callback(val):
|
def _callback(val):
|
||||||
if self.stop_now:
|
if self.stop_now or shared.stop_everything:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
self.q.put(val)
|
self.q.put(val)
|
||||||
|
|
||||||
|
@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check):
|
|||||||
reply = fix_newlines(reply)
|
reply = fix_newlines(reply)
|
||||||
return reply, next_character_found
|
return reply, next_character_found
|
||||||
|
|
||||||
def stop_everything_event():
|
|
||||||
shared.stop_everything = True
|
|
||||||
|
|
||||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
||||||
shared.stop_everything = False
|
|
||||||
just_started = True
|
just_started = True
|
||||||
eos_token = '\n' if check else None
|
eos_token = '\n' if check else None
|
||||||
name1_original = name1
|
name1_original = name1
|
||||||
|
@ -99,9 +99,13 @@ def set_manual_seed(seed):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
def stop_everything_event():
|
||||||
|
shared.stop_everything = True
|
||||||
|
|
||||||
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
|
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
set_manual_seed(seed)
|
set_manual_seed(seed)
|
||||||
|
shared.stop_everything = False
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
original_question = question
|
original_question = question
|
||||||
@ -236,8 +240,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
break
|
break
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
|
||||||
|
|
||||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||||
else:
|
else:
|
||||||
for i in range(max_new_tokens//8+1):
|
for i in range(max_new_tokens//8+1):
|
||||||
|
23
server.py
23
server.py
@ -14,7 +14,8 @@ import modules.extensions as extensions_module
|
|||||||
from modules.html_generator import generate_chat_html
|
from modules.html_generator import generate_chat_html
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt
|
from modules.models import load_model, load_soft_prompt
|
||||||
from modules.text_generation import clear_torch_cache, generate_reply
|
from modules.text_generation import (clear_torch_cache, generate_reply,
|
||||||
|
stop_everything_event)
|
||||||
|
|
||||||
# Loading custom settings
|
# Loading custom settings
|
||||||
settings_file = None
|
settings_file = None
|
||||||
@ -133,7 +134,7 @@ def save_prompt(text):
|
|||||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
|
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
|
||||||
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
|
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
|
||||||
f.write(text)
|
f.write(text)
|
||||||
return f"Saved prompt to prompts/{fname}"
|
return f"Saved to prompts/{fname}"
|
||||||
|
|
||||||
def load_prompt(fname):
|
def load_prompt(fname):
|
||||||
if fname in ['None', '']:
|
if fname in ['None', '']:
|
||||||
@ -154,7 +155,7 @@ def create_prompt_menus():
|
|||||||
shared.gradio['save_prompt'] = gr.Button('Save prompt')
|
shared.gradio['save_prompt'] = gr.Button('Save prompt')
|
||||||
shared.gradio['status'] = gr.Markdown('Ready')
|
shared.gradio['status'] = gr.Markdown('Ready')
|
||||||
|
|
||||||
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=True)
|
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
|
||||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||||
|
|
||||||
def create_settings_menus(default_preset):
|
def create_settings_menus(default_preset):
|
||||||
@ -364,7 +365,7 @@ def create_interface():
|
|||||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||||
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
|
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
|
|
||||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||||
@ -415,11 +416,15 @@ def create_interface():
|
|||||||
shared.gradio['html'] = gr.HTML()
|
shared.gradio['html'] = gr.HTML()
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Generate'] = gr.Button('Generate')
|
with gr.Column():
|
||||||
shared.gradio['Stop'] = gr.Button('Stop')
|
with gr.Row():
|
||||||
|
shared.gradio['Generate'] = gr.Button('Generate')
|
||||||
|
shared.gradio['Stop'] = gr.Button('Stop')
|
||||||
|
with gr.Column():
|
||||||
|
pass
|
||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
gr.Markdown("\n")
|
gr.HTML('<div style="padding-bottom: 13px"></div>')
|
||||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||||
|
|
||||||
create_prompt_menus()
|
create_prompt_menus()
|
||||||
@ -431,7 +436,7 @@ def create_interface():
|
|||||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -465,7 +470,7 @@ def create_interface():
|
|||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
with gr.Tab("Training", elem_id="training-tab"):
|
with gr.Tab("Training", elem_id="training-tab"):
|
||||||
|
Loading…
Reference in New Issue
Block a user