mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00

Reorganize some chat functions

This commit is contained in:
parent ec979cd9c4
commit a453d4e9c4
@@ -105,14 +105,16 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
     else:
         stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
 
-    eos_token = '\n' if generate_state['stop_at_newline'] else None
+    # Defining some variables
+    cumulative_reply = ''
+    just_started = True
     name1_original = name1
+    visible_text = custom_generate_chat_prompt = None
+    eos_token = '\n' if generate_state['stop_at_newline'] else None
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
 
     # Check if any extension wants to hijack this function call
-    visible_text = None
-    custom_generate_chat_prompt = None
     for extension, _ in extensions_module.iterator():
         if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
             extension.input_hijack['state'] = False
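
Note: the hijack check in this hunk assumes extensions expose an input_hijack dict from their script.py; chatbot_wrapper checks the 'state' flag and resets it after consuming the hijack. A minimal sketch of the extension side (the two-element 'value' shape and the helper name are assumptions, not part of this commit):

# script.py of a hypothetical extension (sketch under the assumptions above)
input_hijack = {
    'state': False,    # chatbot_wrapper checks this flag and resets it to False
    'value': ["", ""]  # assumed shape: [text, visible_text] consumed by the wrapper
}

def arm_hijack(text, visible_text):
    # Arm a one-shot hijack of the next chatbot_wrapper call.
    input_hijack['state'] = True
    input_hijack['value'] = [text, visible_text]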
@@ -124,6 +126,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
         visible_text = text
     text = apply_extensions(text, "input")
 
+    # Generating the prompt
     kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
     if custom_generate_chat_prompt is None:
         prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
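
The custom_generate_chat_prompt fallback above means an extension can supply its own prompt builder; only when none does is the default generate_chat_prompt used. A hedged sketch of such an override, whose signature simply mirrors the default call site in this hunk (the import path and the context tweak are illustrative assumptions):

from modules.chat import generate_chat_prompt  # assumed import path for this sketch

def custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
    # Illustrative tweak: prepend an extra instruction, then delegate to the default builder.
    context = 'Stay concise.\n' + context
    return generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs)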
@@ -135,8 +138,6 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
         yield shared.history['visible'] + [[visible_text, shared.processing_message]]
 
     # Generate
-    cumulative_reply = ''
-    just_started = True
     for i in range(generate_state['chat_generation_attempts']):
         reply = None
         for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
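
This hunk hoists cumulative_reply and just_started to the top of the function (see the first hunk). The seeding expression in the generate_reply call is worth unpacking: on the first attempt cumulative_reply is empty and the prompt is passed as-is; on later attempts the text accumulated so far is re-fed after a single space. A self-contained illustration of just that expression:

# Demonstrates the prompt-seeding expression from the loop above.
prompt = 'Bot:'
for cumulative_reply in ('', 'Hello there.'):
    seed = f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}"
    print(repr(seed))
# 'Bot:'               -> first attempt: empty accumulator, no separator added
# 'Bot: Hello there.'  -> later attempts: prior output re-fed after one space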
@@ -175,6 +176,8 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
     else:
         stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
 
+    # Defining some variables
+    cumulative_reply = ''
     eos_token = '\n' if generate_state['stop_at_newline'] else None
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
@@ -184,7 +187,6 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
     # Yield *Is typing...*
     yield shared.processing_message
 
-    cumulative_reply = ''
     for i in range(generate_state['chat_generation_attempts']):
         reply = None
         for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
@@ -264,7 +266,7 @@ def redraw_html(name1, name2, mode):
 
 def tokenize_dialogue(dialogue, name1, name2, mode):
     history = []
-
+    messages = []
     dialogue = re.sub('<START>', '', dialogue)
     dialogue = re.sub('<start>', '', dialogue)
     dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
@@ -273,7 +275,6 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
     if len(idx) == 0:
         return history
 
-    messages = []
     for i in range(len(idx) - 1):
         messages.append(dialogue[idx[i]:idx[i + 1]].strip())
     messages.append(dialogue[idx[-1]:].strip())
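
Taken together, the two tokenize_dialogue hunks move messages = [] to the top of the function without changing behavior. For orientation, a runnable sketch of what the surrounding code does: the substitutions normalize "Anon:" to "You:", and the slicing turns the flat dialogue string into one message per speaker turn. The idx list is built between the two hunks and is not shown on this page, so its construction below is an assumption:

import re

dialogue = 'Anon: hi\nBot: hello\nanon: bye'
dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('<start>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
print(repr(dialogue))  # 'You: hi\nBot: hello\nYou: bye'

# Assumed: idx holds the start offset of every speaker turn.
idx = [m.start() for m in re.finditer('(^|\n)(You|Bot):', dialogue)]

messages = []
for i in range(len(idx) - 1):
    messages.append(dialogue[idx[i]:idx[i + 1]].strip())
messages.append(dialogue[idx[-1]:].strip())
print(messages)  # ['You: hi', 'Bot: hello', 'You: bye']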