mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Merge branch 'main' into mcmonkey4eva-add-train-lora-tab
This commit is contained in:
commit
c2cad30772
1
.gitignore
vendored
1
.gitignore
vendored
@ -19,3 +19,4 @@ repositories
|
||||
settings.json
|
||||
img_bot*
|
||||
img_me*
|
||||
prompts/[0-9]*
|
||||
|
11
css/main.css
11
css/main.css
@ -37,12 +37,6 @@
|
||||
text-decoration: none !important;
|
||||
}
|
||||
|
||||
svg {
|
||||
display: unset !important;
|
||||
vertical-align: middle !important;
|
||||
margin: 5px;
|
||||
}
|
||||
|
||||
ol li p, ul li p {
|
||||
display: inline-block;
|
||||
}
|
||||
@ -64,3 +58,8 @@ ol li p, ul li p {
|
||||
padding: 15px;
|
||||
padding: 15px;
|
||||
}
|
||||
|
||||
span.math.inline {
|
||||
font-size: 27px;
|
||||
vertical-align: baseline !important;
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ class Iteratorize:
|
||||
self.stop_now = False
|
||||
|
||||
def _callback(val):
|
||||
if self.stop_now:
|
||||
if self.stop_now or shared.stop_everything:
|
||||
raise ValueError
|
||||
self.q.put(val)
|
||||
|
||||
|
@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check):
|
||||
reply = fix_newlines(reply)
|
||||
return reply, next_character_found
|
||||
|
||||
def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
||||
shared.stop_everything = False
|
||||
just_started = True
|
||||
eos_token = '\n' if check else None
|
||||
name1_original = name1
|
||||
|
@ -99,9 +99,13 @@ def set_manual_seed(seed):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
|
||||
clear_torch_cache()
|
||||
set_manual_seed(seed)
|
||||
shared.stop_everything = False
|
||||
t0 = time.time()
|
||||
|
||||
original_question = question
|
||||
@ -236,8 +240,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
break
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(max_new_tokens//8+1):
|
||||
|
23
server.py
23
server.py
@ -14,7 +14,8 @@ import modules.extensions as extensions_module
|
||||
from modules.html_generator import generate_chat_html
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt
|
||||
from modules.text_generation import clear_torch_cache, generate_reply
|
||||
from modules.text_generation import (clear_torch_cache, generate_reply,
|
||||
stop_everything_event)
|
||||
|
||||
# Loading custom settings
|
||||
settings_file = None
|
||||
@ -133,7 +134,7 @@ def save_prompt(text):
|
||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
|
||||
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
|
||||
f.write(text)
|
||||
return f"Saved prompt to prompts/{fname}"
|
||||
return f"Saved to prompts/{fname}"
|
||||
|
||||
def load_prompt(fname):
|
||||
if fname in ['None', '']:
|
||||
@ -154,7 +155,7 @@ def create_prompt_menus():
|
||||
shared.gradio['save_prompt'] = gr.Button('Save prompt')
|
||||
shared.gradio['status'] = gr.Markdown('Ready')
|
||||
|
||||
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=True)
|
||||
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
@ -364,7 +365,7 @@ def create_interface():
|
||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
@ -415,11 +416,15 @@ def create_interface():
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
with gr.Column():
|
||||
pass
|
||||
|
||||
with gr.Column(scale=1):
|
||||
gr.Markdown("\n")
|
||||
gr.HTML('<div style="padding-bottom: 13px"></div>')
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
|
||||
create_prompt_menus()
|
||||
@ -431,7 +436,7 @@ def create_interface():
|
||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
else:
|
||||
@ -465,7 +470,7 @@ def create_interface():
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
with gr.Tab("Training", elem_id="training-tab"):
|
||||
|
Loading…
Reference in New Issue
Block a user