mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-11-26 09:40:20 +01:00
Add impersonate feature to API /v1/chat/completions

parent c24966c591
commit ce6a836b46
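The change lets /v1/chat/completions generate text as the user instead of the assistant: with `impersonate` set, the model continues the last user message rather than writing a bot reply. A minimal client sketch, assuming a local server with the OpenAI-compatible API enabled on its default port (the URL, port, and payload values here are illustrative assumptions, not part of the commit):

import requests

# Assumed local endpoint; adjust host/port to your own server settings.
URL = "http://127.0.0.1:5000/v1/chat/completions"

payload = {
    "mode": "chat",
    "messages": [{"role": "user", "content": "I was walking home when"}],
    "impersonate": True,  # new flag from this commit; defaults to False
}

# The reply continues the user's message instead of answering it.
response = requests.post(URL, json=payload)
print(response.json()["choices"][0]["message"]["content"])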
extensions/openai/completions.py

@@ -16,6 +16,7 @@ from extensions.openai.errors import InvalidRequestError
 from extensions.openai.utils import debug_msg
 from modules import shared
 from modules.chat import (
+    get_stopping_strings,
     generate_chat_prompt,
     generate_chat_reply,
     load_character_memoized,
@@ -242,6 +243,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
     # generation parameters
     generate_params = process_parameters(body, is_legacy=is_legacy)
     continue_ = body['continue_']
+    impersonate = body['impersonate']
+    if impersonate:
+        continue_ = False
 
     # Instruction template
     if body['instruction_template_str']:
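Impersonation and `continue_` are mutually exclusive by construction: continuing the last bot reply makes no sense while generating as the user, so `impersonate` takes precedence. A small sketch of that precedence, assuming a client sets both flags:

# If a client sets both flags, impersonate silently wins and continue_ is dropped.
body = {"continue_": True, "impersonate": True}

continue_ = body['continue_']
impersonate = body['impersonate']
if impersonate:
    continue_ = False

print(impersonate, continue_)  # -> True False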
@@ -294,6 +298,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
 
     def chat_streaming_chunk(content):
         # begin streaming
+        role = 'user' if impersonate else 'assistant'
         chunk = {
             "id": cmpl_id,
             "object": object_type,
@@ -302,7 +307,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
             resp_list: [{
                 "index": 0,
                 "finish_reason": None,
-                "delta": {'role': 'assistant', 'content': content},
+                "delta": {'role': role, 'content': content},
             }],
         }
 
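With the role derived from the flag, streamed deltas during impersonation are attributed to the user. An illustrative sketch of one streamed chunk when `impersonate` is set, with placeholder values for the id, timestamp, and model (the real values come from the server):

# Illustrative shape only; id, created, and model are placeholders.
chunk = {
    "id": "chatcmpl-123456",
    "object": "chat.completion.chunk",
    "created": 1700000000,
    "model": "example-model",
    "choices": [{
        "index": 0,
        "finish_reason": None,
        "delta": {"role": "user", "content": " the streetlights flickered"},
    }],
}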
@@ -314,9 +319,12 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
         return chunk
 
     # generate reply #######################################
-    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
+    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_, impersonate=impersonate)
     if prompt_only:
-        yield {'prompt': prompt}
+        if impersonate:
+            yield {'prompt': prompt + user_input}
+        else:
+            yield {'prompt': prompt}
         return
 
     debug_msg({'prompt': prompt, 'generate_params': generate_params})
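The `prompt + user_input` concatenation reflects how `generate_chat_prompt` behaves when impersonating: the prompt ends at the user's name prefix, and the partial user message is appended on top so the model continues it. A rough sketch of the resulting string, where the template text is an assumption (the real prefix depends on the loaded character and template):

# Assumed template output; the actual prefix depends on the character/template.
prompt = "AI: Hello! How can I help you today?\nYou: "  # generate_chat_prompt(..., impersonate=True)
user_input = "I was walking home when"                  # partial user message to continue

full_prompt = prompt + user_input
print(full_prompt)
# AI: Hello! How can I help you today?
# You: I was walking home when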
@@ -324,14 +332,19 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
     if stream:
         yield chat_streaming_chunk('')
 
-    generator = generate_chat_reply(
-        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
+    if impersonate:
+        stopping_strings = get_stopping_strings(generate_params)
+        generator = generate_reply(prompt + user_input, generate_params, stopping_strings=stopping_strings, is_chat=True)
+
+    else:
+        generator = generate_chat_reply(
+            user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
 
     answer = ''
     seen_content = ''
 
     for a in generator:
-        answer = a['internal'][-1][1]
+        answer = a if impersonate else a['internal'][-1][1]
         if stream:
             len_seen = len(seen_content)
             new_content = answer[len_seen:]
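The loop branches because the two generators yield different shapes: `generate_reply` yields the cumulative reply as a plain string, while `generate_chat_reply` yields a history dict whose `'internal'` list holds `[input, reply]` pairs. A self-contained sketch of that normalization, with stand-in generators assuming those shapes:

impersonate = True

def fake_generate_reply():
    # Stand-in for generate_reply: yields the growing reply text directly.
    yield "I was walking"
    yield "I was walking home when it started to rain"

def fake_generate_chat_reply():
    # Stand-in for generate_chat_reply: yields a history dict per step.
    yield {'internal': [["hello", "Hi there"], ["tell me a story", "Once upon a time"]]}

generator = fake_generate_reply() if impersonate else fake_generate_chat_reply()
for a in generator:
    # Same normalization as the patched loop above.
    answer = a if impersonate else a['internal'][-1][1]
    print(answer)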
extensions/openai/typing.py

@@ -114,6 +114,8 @@ class ChatCompletionRequestParams(BaseModel):
 
     continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
 
+    impersonate: bool = Field(default=False, description="Impersonate the user in the chat. Makes the model continue generating the last user message.")
+
 
 class ChatCompletionRequest(GenerationOptions, ChatCompletionRequestParams):
     pass
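Because the field defaults to False, existing clients are unaffected: pydantic fills in the default before the handler reads `body['impersonate']`. A minimal sketch of that defaulting with a stand-in model (the real `ChatCompletionRequestParams` carries many more fields):

from pydantic import BaseModel, Field

class Params(BaseModel):
    # Stand-in for ChatCompletionRequestParams; only the two relevant fields.
    continue_: bool = Field(default=False)
    impersonate: bool = Field(default=False)

params = Params()  # an old client that sends neither flag
print(params.continue_, params.impersonate)  # -> False False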