Add impersonate feature to API /v1/chat/completions

Yiximail 2024-08-22 15:56:37 +08:00
parent c24966c591
commit ce6a836b46
2 changed files with 21 additions and 6 deletions

extensions/openai/completions.py

@@ -16,6 +16,7 @@ from extensions.openai.errors import InvalidRequestError
 from extensions.openai.utils import debug_msg
 from modules import shared
 from modules.chat import (
+    get_stopping_strings,
     generate_chat_prompt,
     generate_chat_reply,
     load_character_memoized,
@@ -242,6 +243,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
     # generation parameters
     generate_params = process_parameters(body, is_legacy=is_legacy)
     continue_ = body['continue_']
+    impersonate = body['impersonate']
+    if impersonate:
+        continue_ = False
 
     # Instruction template
     if body['instruction_template_str']:
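
Note that the new flag wins over continue_: the handler above resets continue_ to False whenever impersonate is set. A minimal client sketch of a request using the flag (host, port, and message content are assumptions, not part of the commit):

import requests

# Hypothetical request body; only 'impersonate' and 'continue_' come from this commit.
body = {
    "messages": [{"role": "user", "content": "The weather today is"}],
    "impersonate": True,   # continue the last user message instead of replying
    "continue_": True,     # no effect here: the handler forces it to False
}
response = requests.post("http://127.0.0.1:5000/v1/chat/completions", json=body)
print(response.json())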
@@ -294,6 +298,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
     def chat_streaming_chunk(content):
         # begin streaming
+        role = 'user' if impersonate else 'assistant'
         chunk = {
             "id": cmpl_id,
             "object": object_type,
@@ -302,7 +307,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
             resp_list: [{
                 "index": 0,
                 "finish_reason": None,
-                "delta": {'role': 'assistant', 'content': content},
+                "delta": {'role': role, 'content': content},
             }],
         }
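
With that change, a streamed chunk during impersonation carries the user role in its delta. A sketch of one chunk's shape, assuming the non-legacy endpoint (where resp_list resolves to 'choices') and with placeholder id and object values:

# Placeholder values throughout; only the 'role' behavior is from this diff.
chunk = {
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",
    "choices": [{
        "index": 0,
        "finish_reason": None,
        "delta": {"role": "user", "content": " sunny"},  # 'user' when impersonate=True
    }],
}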
@@ -314,9 +319,12 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
         return chunk
 
     # generate reply #######################################
-    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
+    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_, impersonate=impersonate)
     if prompt_only:
-        yield {'prompt': prompt}
+        if impersonate:
+            yield {'prompt': prompt + user_input}
+        else:
+            yield {'prompt': prompt}
         return
 
     debug_msg({'prompt': prompt, 'generate_params': generate_params})
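
The prompt + user_input concatenation reflects how impersonation works: generate_chat_prompt(..., impersonate=True) renders the history up to the start of the user's turn, so appending the partial user_input reproduces the full text the model will continue. A rough illustration with invented strings (the real prompt depends on the chat template):

# Invented example values for illustration only.
prompt = "Assistant: Sure, here it is.\nUser:"  # from generate_chat_prompt(..., impersonate=True)
user_input = " One more thing:"                 # partial user message to be continued
full_text = prompt + user_input                 # what prompt_only yields and generation consumes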
@@ -324,14 +332,19 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
     if stream:
         yield chat_streaming_chunk('')
 
-    generator = generate_chat_reply(
-        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
+    if impersonate:
+        stopping_strings = get_stopping_strings(generate_params)
+        generator = generate_reply(prompt + user_input, generate_params, stopping_strings=stopping_strings, is_chat=True)
+    else:
+        generator = generate_chat_reply(
+            user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
 
     answer = ''
     seen_content = ''
 
     for a in generator:
-        answer = a['internal'][-1][1]
+        answer = a if impersonate else a['internal'][-1][1]
 
         if stream:
             len_seen = len(seen_content)
             new_content = answer[len_seen:]
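
The branch on a exists because the two generators yield different shapes: generate_reply yields the raw completion text accumulated so far, while generate_chat_reply yields a history dict whose last internal pair holds the reply. A hypothetical helper (not in the commit) that makes the normalization explicit:

def iter_answer_text(generator, impersonate):
    # Hypothetical wrapper: normalizes both generator shapes to plain text.
    for a in generator:
        if impersonate:
            yield a                      # generate_reply: the string so far
        else:
            yield a['internal'][-1][1]   # generate_chat_reply: last reply in the history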

extensions/openai/typing.py

@@ -114,6 +114,8 @@ class ChatCompletionRequestParams(BaseModel):
     continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
 
+    impersonate: bool = Field(default=False, description="Impersonate the user in the chat. Makes the model continue generating the last user message.")
+
 
 class ChatCompletionRequest(GenerationOptions, ChatCompletionRequestParams):
     pass
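
A quick sketch of how the new field surfaces through the Pydantic model; the import path and the assumption that messages is the only required field are illustrative, not part of the commit:

from extensions.openai.typing import ChatCompletionRequest

# Defaults to False, matching the Field definition above.
req = ChatCompletionRequest(messages=[{"role": "user", "content": "Hi"}])
print(req.impersonate)  # False

req = ChatCompletionRequest(messages=[{"role": "user", "content": "Hi"}], impersonate=True)
print(req.impersonate)  # True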