diff --git a/server.py b/server.py index bfebdc33..a540ca45 100644 --- a/server.py +++ b/server.py @@ -597,6 +597,9 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, yield history['visible'] def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): + if 'pygmalion' in model_name.lower(): + name1 = "You" + question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True) eos_token = '\n' if check else None for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):