Implement regenerate/impersonate the proper way (fixes #78)

This commit is contained in:
oobabooga 2023-02-15 14:39:26 -03:00
parent 5ee9283cae
commit b3bcd2881d

View File

@ -599,28 +599,26 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False) reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
if not substring_found: if not substring_found:
yield apply_extensions(reply, "output") yield reply
if next_character_found: if next_character_found:
break break
yield apply_extensions(reply, "output") yield reply
def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
yield generate_chat_html(_history, name1, name2, character) yield generate_chat_html(_history, name1, name2, character)
def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
last = history['visible'].pop() last_visible = history['visible'].pop()
last_internal = history['internal'].pop()
# Fix for when the last sent message was an image for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
if last[0].startswith('<img src="'): if args.cai_chat:
last[0] = history['internal'].pop()[0] history['visible'][-1] = [last_visible[0], _history[-1][1]]
yield generate_chat_html(history['visible'], name1, name2, character)
else: else:
history['internal'].pop() history['visible'][-1] = (last_visible[0], _history[-1][1])
yield history['visible']
text = last[0]
function_call = "cai_chatbot_wrapper" if args.cai_chat else "chatbot_wrapper"
for i in eval(function_call)(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
yield i
def remove_last_message(name1, name2): def remove_last_message(name1, name2):
if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':