From b0890a79257b4d2da7314a05b021f8c010f7ffbb Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 1 Apr 2023 20:14:43 -0300
Subject: [PATCH] Add shared.is_chat() function

---
 extensions/llama_prompts/script.py |  2 +-
 extensions/silero_tts/script.py    |  2 +-
 modules/shared.py                  |  4 +++-
 modules/text_generation.py         | 18 +++++++++---------
 server.py                          |  6 +++---
 5 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/extensions/llama_prompts/script.py b/extensions/llama_prompts/script.py
index 22c96f7c..e40ac5c0 100644
--- a/extensions/llama_prompts/script.py
+++ b/extensions/llama_prompts/script.py
@@ -11,7 +11,7 @@ def get_prompt_by_name(name):
         return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
 
 def ui():
-    if not shared.args.chat or shared.args.cai_chat:
+    if not shared.is_chat():
         choices = ['None'] + list(df['Prompt name'])
 
         prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py
index 1352993a..6ee617c8 100644
--- a/extensions/silero_tts/script.py
+++ b/extensions/silero_tts/script.py
@@ -74,7 +74,7 @@ def input_modifier(string):
     """
 
     # Remove autoplay from the last reply
-    if (shared.args.chat or shared.args.cai_chat) and len(shared.history['internal']) > 0:
+    if shared.is_chat() and len(shared.history['internal']) > 0:
         shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
 
     shared.processing_message = "*Is recording a voice message...*"
diff --git a/modules/shared.py b/modules/shared.py
index 65ae7fb0..c4225586 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -129,10 +129,12 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
 
 args = parser.parse_args()
 
-
 # Provisional, this will be deleted later
 deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
 for k in deprecated_dict:
     if eval(f"args.{k}") != deprecated_dict[k][1]:
         print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
         exec(f"args.{deprecated_dict[k][0]} = args.{k}")
+
+def is_chat():
+    return any((args.chat, args.cai_chat))
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 6ae592db..406c4548 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -76,7 +76,7 @@ def fix_galactica(s):
     return s
 
 def formatted_outputs(reply, model_name):
-    if not (shared.args.chat or shared.args.cai_chat):
+    if not shared.is_chat():
         if 'galactica' in model_name.lower():
             reply = fix_galactica(reply)
             return reply, reply, generate_basic_html(reply)
@@ -109,7 +109,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     t0 = time.time()
 
     original_question = question
-    if not (shared.args.chat or shared.args.cai_chat):
+    if not shared.is_chat():
         question = apply_extensions(question, "input")
     if shared.args.verbose:
         print(f"\n\n{question}\n--------------------\n")
@@ -121,18 +121,18 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         if shared.args.no_stream:
             reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
             output = original_question+reply
-            if not (shared.args.chat or shared.args.cai_chat):
+            if not shared.is_chat():
                 reply = original_question + apply_extensions(reply, "output")
             yield formatted_outputs(reply, shared.model_name)
         else:
-            if not (shared.args.chat or shared.args.cai_chat):
+            if not shared.is_chat():
                 yield formatted_outputs(question, shared.model_name)
 
             # RWKV has proper streaming, which is very nice.
             # No need to generate 8 tokens at a time.
             for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
                 output = original_question+reply
-                if not (shared.args.chat or shared.args.cai_chat):
+                if not shared.is_chat():
                     reply = original_question + apply_extensions(reply, "output")
                 yield formatted_outputs(reply, shared.model_name)
 
@@ -208,7 +208,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         new_tokens = len(output) - len(input_ids[0])
         reply = decode(output[-new_tokens:])
 
-        if not (shared.args.chat or shared.args.cai_chat):
+        if not shared.is_chat():
             reply = original_question + apply_extensions(reply, "output")
 
         yield formatted_outputs(reply, shared.model_name)
@@ -226,7 +226,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         def generate_with_streaming(**kwargs):
             return Iteratorize(generate_with_callback, kwargs, callback=None)
 
-        if not (shared.args.chat or shared.args.cai_chat):
+        if not shared.is_chat():
            yield formatted_outputs(original_question, shared.model_name)
         with generate_with_streaming(**generate_params) as generator:
             for output in generator:
@@ -235,7 +235,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 new_tokens = len(output) - len(input_ids[0])
                 reply = decode(output[-new_tokens:])
 
-                if not (shared.args.chat or shared.args.cai_chat):
+                if not shared.is_chat():
                     reply = original_question + apply_extensions(reply, "output")
 
                 if output[-1] in eos_token_ids:
@@ -253,7 +253,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             new_tokens = len(output) - len(original_input_ids[0])
             reply = decode(output[-new_tokens:])
 
-            if not (shared.args.chat or shared.args.cai_chat):
+            if not shared.is_chat():
                 reply = original_question + apply_extensions(reply, "output")
 
             if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
diff --git a/server.py b/server.py
index ebd9c81e..e1caf967 100644
--- a/server.py
+++ b/server.py
@@ -244,7 +244,7 @@ available_loras = get_available_loras()
 
 # Default extensions
 extensions_module.available_extensions = get_available_extensions()
-if shared.args.chat or shared.args.cai_chat:
+if shared.is_chat():
     for extension in shared.settings['chat_default_extensions']:
         shared.args.extensions = shared.args.extensions or []
         if extension not in shared.args.extensions:
@@ -290,8 +290,8 @@ def create_interface():
     if shared.args.extensions is not None and len(shared.args.extensions) > 0:
         extensions_module.load_extensions()
 
-    with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
-        if shared.args.chat or shared.args.cai_chat:
+    with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+        if shared.is_chat():
             with gr.Tab("Text generation", elem_id="main"):
                 if shared.args.cai_chat:
                     shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
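
Note on the change (not part of the patch): the commit collapses every scattered
`shared.args.chat or shared.args.cai_chat` check into a single `shared.is_chat()`
helper, so a future chat mode only needs one function updated. A minimal,
runnable sketch of the idea follows; the two-flag parser is a hypothetical
stand-in for the much larger one in modules/shared.py.

    import argparse

    # Hypothetical trimmed-down parser; the real one in modules/shared.py
    # defines many more options than these two flags.
    parser = argparse.ArgumentParser()
    parser.add_argument('--chat', action='store_true')
    parser.add_argument('--cai-chat', action='store_true')
    args = parser.parse_args(['--cai-chat'])  # simulate launching in cai-chat mode

    def is_chat():
        # True when either chat UI flag is set; equivalent to
        # `args.chat or args.cai_chat`, but always returns a plain bool.
        return any((args.chat, args.cai_chat))

    print(is_chat())  # prints: True

Call sites then shrink from `if not (shared.args.chat or shared.args.cai_chat):`
to `if not shared.is_chat():`. As a side effect, the helper also fixes the
operator-precedence bug in the line removed from extensions/llama_prompts/script.py,
where the unparenthesized `if not shared.args.chat or shared.args.cai_chat:`
parsed as `(not shared.args.chat) or shared.args.cai_chat`.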