Add impersonate feature to API /v1/chat/completions

This commit is contained in:
Yiximail 2024-08-22 15:56:37 +08:00
parent c24966c591
commit ce6a836b46
2 changed files with 21 additions and 6 deletions

View File

@ -16,6 +16,7 @@ from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg from extensions.openai.utils import debug_msg
from modules import shared from modules import shared
from modules.chat import ( from modules.chat import (
get_stopping_strings,
generate_chat_prompt, generate_chat_prompt,
generate_chat_reply, generate_chat_reply,
load_character_memoized, load_character_memoized,
@ -242,6 +243,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
# generation parameters # generation parameters
generate_params = process_parameters(body, is_legacy=is_legacy) generate_params = process_parameters(body, is_legacy=is_legacy)
continue_ = body['continue_'] continue_ = body['continue_']
impersonate = body['impersonate']
if impersonate:
continue_ = False
# Instruction template # Instruction template
if body['instruction_template_str']: if body['instruction_template_str']:
@ -294,6 +298,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
def chat_streaming_chunk(content): def chat_streaming_chunk(content):
# begin streaming # begin streaming
role = 'user' if impersonate else 'assistant'
chunk = { chunk = {
"id": cmpl_id, "id": cmpl_id,
"object": object_type, "object": object_type,
@ -302,7 +307,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
resp_list: [{ resp_list: [{
"index": 0, "index": 0,
"finish_reason": None, "finish_reason": None,
"delta": {'role': 'assistant', 'content': content}, "delta": {'role': role, 'content': content},
}], }],
} }
@ -314,8 +319,11 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
return chunk return chunk
# generate reply ####################################### # generate reply #######################################
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_) prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_, impersonate=impersonate)
if prompt_only: if prompt_only:
if impersonate:
yield {'prompt': prompt + user_input}
else:
yield {'prompt': prompt} yield {'prompt': prompt}
return return
@ -324,6 +332,11 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
if stream: if stream:
yield chat_streaming_chunk('') yield chat_streaming_chunk('')
if impersonate:
stopping_strings = get_stopping_strings(generate_params)
generator = generate_reply(prompt + user_input, generate_params, stopping_strings=stopping_strings, is_chat=True)
else:
generator = generate_chat_reply( generator = generate_chat_reply(
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False) user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
@ -331,7 +344,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
seen_content = '' seen_content = ''
for a in generator: for a in generator:
answer = a['internal'][-1][1] answer = a if impersonate else a['internal'][-1][1]
if stream: if stream:
len_seen = len(seen_content) len_seen = len(seen_content)
new_content = answer[len_seen:] new_content = answer[len_seen:]

View File

@ -114,6 +114,8 @@ class ChatCompletionRequestParams(BaseModel):
continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.") continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
impersonate: bool = Field(default=False, description="Impersonate the user in the chat. Makes the model continue generate the last user message.")
class ChatCompletionRequest(GenerationOptions, ChatCompletionRequestParams): class ChatCompletionRequest(GenerationOptions, ChatCompletionRequestParams):
pass pass