mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Reorganize some chat functions
This commit is contained in:
parent
ec979cd9c4
commit
a453d4e9c4
@ -105,14 +105,16 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
just_started = True
|
||||
name1_original = name1
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
|
||||
# Check if any extension wants to hijack this function call
|
||||
visible_text = None
|
||||
custom_generate_chat_prompt = None
|
||||
for extension, _ in extensions_module.iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
@ -124,6 +126,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
visible_text = text
|
||||
text = apply_extensions(text, "input")
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
@ -135,8 +138,6 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
cumulative_reply = ''
|
||||
just_started = True
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
@ -175,6 +176,8 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
@ -184,7 +187,6 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
||||
# Yield *Is typing...*
|
||||
yield shared.processing_message
|
||||
|
||||
cumulative_reply = ''
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
@ -264,7 +266,7 @@ def redraw_html(name1, name2, mode):
|
||||
|
||||
def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
history = []
|
||||
|
||||
messages = []
|
||||
dialogue = re.sub('<START>', '', dialogue)
|
||||
dialogue = re.sub('<start>', '', dialogue)
|
||||
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
||||
@ -273,7 +275,6 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
if len(idx) == 0:
|
||||
return history
|
||||
|
||||
messages = []
|
||||
for i in range(len(idx) - 1):
|
||||
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
|
||||
messages.append(dialogue[idx[-1]:].strip())
|
||||
|
Loading…
Reference in New Issue
Block a user