mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Add shared.is_chat() function
This commit is contained in:
parent
b38ba230f4
commit
b0890a7925
@ -11,7 +11,7 @@ def get_prompt_by_name(name):
|
|||||||
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
if not shared.args.chat or shared.args.cai_chat:
|
if not shared.is_chat():
|
||||||
choices = ['None'] + list(df['Prompt name'])
|
choices = ['None'] + list(df['Prompt name'])
|
||||||
|
|
||||||
prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
|
prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
|
||||||
|
@ -74,7 +74,7 @@ def input_modifier(string):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Remove autoplay from the last reply
|
# Remove autoplay from the last reply
|
||||||
if (shared.args.chat or shared.args.cai_chat) and len(shared.history['internal']) > 0:
|
if shared.is_chat() and len(shared.history['internal']) > 0:
|
||||||
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
|
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
|
||||||
|
|
||||||
shared.processing_message = "*Is recording a voice message...*"
|
shared.processing_message = "*Is recording a voice message...*"
|
||||||
|
@ -129,10 +129,12 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
# Provisional, this will be deleted later
|
# Provisional, this will be deleted later
|
||||||
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
|
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
|
||||||
for k in deprecated_dict:
|
for k in deprecated_dict:
|
||||||
if eval(f"args.{k}") != deprecated_dict[k][1]:
|
if eval(f"args.{k}") != deprecated_dict[k][1]:
|
||||||
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
|
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
|
||||||
exec(f"args.{deprecated_dict[k][0]} = args.{k}")
|
exec(f"args.{deprecated_dict[k][0]} = args.{k}")
|
||||||
|
|
||||||
|
def is_chat():
|
||||||
|
return any((args.chat, args.cai_chat))
|
||||||
|
@ -76,7 +76,7 @@ def fix_galactica(s):
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
def formatted_outputs(reply, model_name):
|
def formatted_outputs(reply, model_name):
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not shared.is_chat():
|
||||||
if 'galactica' in model_name.lower():
|
if 'galactica' in model_name.lower():
|
||||||
reply = fix_galactica(reply)
|
reply = fix_galactica(reply)
|
||||||
return reply, reply, generate_basic_html(reply)
|
return reply, reply, generate_basic_html(reply)
|
||||||
@ -109,7 +109,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
original_question = question
|
original_question = question
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not shared.is_chat():
|
||||||
question = apply_extensions(question, "input")
|
question = apply_extensions(question, "input")
|
||||||
if shared.args.verbose:
|
if shared.args.verbose:
|
||||||
print(f"\n\n{question}\n--------------------\n")
|
print(f"\n\n{question}\n--------------------\n")
|
||||||
@ -121,18 +121,18 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
if shared.args.no_stream:
|
if shared.args.no_stream:
|
||||||
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
|
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
|
||||||
output = original_question+reply
|
output = original_question+reply
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, "output")
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
else:
|
else:
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not shared.is_chat():
|
||||||
yield formatted_outputs(question, shared.model_name)
|
yield formatted_outputs(question, shared.model_name)
|
||||||
|
|
||||||
# RWKV has proper streaming, which is very nice.
|
# RWKV has proper streaming, which is very nice.
|
||||||
# No need to generate 8 tokens at a time.
|
# No need to generate 8 tokens at a time.
|
||||||
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
|
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
|
||||||
output = original_question+reply
|
output = original_question+reply
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, "output")
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
@ -208,7 +208,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
|
|
||||||
new_tokens = len(output) - len(input_ids[0])
|
new_tokens = len(output) - len(input_ids[0])
|
||||||
reply = decode(output[-new_tokens:])
|
reply = decode(output[-new_tokens:])
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, "output")
|
||||||
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
@ -226,7 +226,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
def generate_with_streaming(**kwargs):
|
def generate_with_streaming(**kwargs):
|
||||||
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||||
|
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not shared.is_chat():
|
||||||
yield formatted_outputs(original_question, shared.model_name)
|
yield formatted_outputs(original_question, shared.model_name)
|
||||||
with generate_with_streaming(**generate_params) as generator:
|
with generate_with_streaming(**generate_params) as generator:
|
||||||
for output in generator:
|
for output in generator:
|
||||||
@ -235,7 +235,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
|
|
||||||
new_tokens = len(output) - len(input_ids[0])
|
new_tokens = len(output) - len(input_ids[0])
|
||||||
reply = decode(output[-new_tokens:])
|
reply = decode(output[-new_tokens:])
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, "output")
|
||||||
|
|
||||||
if output[-1] in eos_token_ids:
|
if output[-1] in eos_token_ids:
|
||||||
@ -253,7 +253,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
|
|
||||||
new_tokens = len(output) - len(original_input_ids[0])
|
new_tokens = len(output) - len(original_input_ids[0])
|
||||||
reply = decode(output[-new_tokens:])
|
reply = decode(output[-new_tokens:])
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, "output")
|
||||||
|
|
||||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||||
|
@ -244,7 +244,7 @@ available_loras = get_available_loras()
|
|||||||
|
|
||||||
# Default extensions
|
# Default extensions
|
||||||
extensions_module.available_extensions = get_available_extensions()
|
extensions_module.available_extensions = get_available_extensions()
|
||||||
if shared.args.chat or shared.args.cai_chat:
|
if shared.is_chat():
|
||||||
for extension in shared.settings['chat_default_extensions']:
|
for extension in shared.settings['chat_default_extensions']:
|
||||||
shared.args.extensions = shared.args.extensions or []
|
shared.args.extensions = shared.args.extensions or []
|
||||||
if extension not in shared.args.extensions:
|
if extension not in shared.args.extensions:
|
||||||
@ -290,8 +290,8 @@ def create_interface():
|
|||||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||||
extensions_module.load_extensions()
|
extensions_module.load_extensions()
|
||||||
|
|
||||||
with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||||
if shared.args.chat or shared.args.cai_chat:
|
if shared.is_chat():
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
if shared.args.cai_chat:
|
if shared.args.cai_chat:
|
||||||
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
|
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
|
||||||
|
Loading…
Reference in New Issue
Block a user