diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 646dee2d..c9e1b3d0 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -16,6 +16,7 @@ from extensions.openai.errors import InvalidRequestError
 from extensions.openai.utils import debug_msg
 from modules import shared
 from modules.chat import (
+    get_stopping_strings,
     generate_chat_prompt,
     generate_chat_reply,
     load_character_memoized,
@@ -242,6 +243,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     # generation parameters
     generate_params = process_parameters(body, is_legacy=is_legacy)
     continue_ = body['continue_']
+    impersonate = body['impersonate']
+    if impersonate:
+        continue_ = False
 
     # Instruction template
     if body['instruction_template_str']:
@@ -294,6 +298,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
 
     def chat_streaming_chunk(content):
         # begin streaming
+        role = 'user' if impersonate else 'assistant'
         chunk = {
             "id": cmpl_id,
             "object": object_type,
@@ -302,7 +307,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             resp_list: [{
                 "index": 0,
                 "finish_reason": None,
-                "delta": {'role': 'assistant', 'content': content},
+                "delta": {'role': role, 'content': content},
             }],
         }
 
@@ -314,9 +319,12 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         return chunk
 
     # generate reply #######################################
-    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
+    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_, impersonate=impersonate)
     if prompt_only:
-        yield {'prompt': prompt}
+        if impersonate:
+            yield {'prompt': prompt + user_input}
+        else:
+            yield {'prompt': prompt}
         return
 
     debug_msg({'prompt': prompt, 'generate_params': generate_params})
@@ -324,14 +332,19 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     if stream:
         yield chat_streaming_chunk('')
 
-    generator = generate_chat_reply(
-        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
+    if impersonate:
+        stopping_strings = get_stopping_strings(generate_params)
+        generator = generate_reply(prompt + user_input, generate_params, stopping_strings=stopping_strings, is_chat=True)
+
+    else:
+        generator = generate_chat_reply(
+            user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
 
     answer = ''
     seen_content = ''
     for a in generator:
-        answer = a['internal'][-1][1]
+        answer = a if impersonate else a['internal'][-1][1]
 
         if stream:
             len_seen = len(seen_content)
             new_content = answer[len_seen:]
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 4015f6a1..6b91e9ea 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -114,6 +114,8 @@ class ChatCompletionRequestParams(BaseModel):
 
     continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
 
+    impersonate: bool = Field(default=False, description="Impersonate the user in the chat. Makes the model continue generating the last user message.")
+
 
 class ChatCompletionRequest(GenerationOptions, ChatCompletionRequestParams):
     pass
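
For reference, a minimal sketch of how a client could exercise the new flag against the extension's OpenAI-compatible /v1/chat/completions endpoint. The host, port, prompt text, and max_tokens value below are illustrative assumptions, not part of this patch.

# Sketch only: assumes the server is running locally with the API enabled
# (host and port are placeholders). With "impersonate": true, the server
# forces continue_ = False and continues the last *user* message instead
# of generating an assistant reply.
import requests

payload = {
    "messages": [{"role": "user", "content": "Once upon a time, I"}],
    "impersonate": True,
    "max_tokens": 64,
}

resp = requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])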