Mirror of https://github.com/oobabooga/text-generation-webui.git
Synced 2024-11-23 00:18:20 +01:00

Add impersonate feature to API /v1/chat/completions

parent c24966c591
commit ce6a836b46
extensions/openai/completions.py

@@ -16,6 +16,7 @@ from extensions.openai.errors import InvalidRequestError
 from extensions.openai.utils import debug_msg
 from modules import shared
 from modules.chat import (
+    get_stopping_strings,
     generate_chat_prompt,
     generate_chat_reply,
     load_character_memoized,
@@ -242,6 +243,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     # generation parameters
     generate_params = process_parameters(body, is_legacy=is_legacy)
     continue_ = body['continue_']
+    impersonate = body['impersonate']
+    if impersonate:
+        continue_ = False

     # Instruction template
     if body['instruction_template_str']:
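The new flag rides alongside the existing parameters, and impersonation takes precedence over continuation: impersonate makes the model write (or finish) the next user message, so extending the last bot reply no longer applies. A minimal sketch of the precedence logic above, in isolation:

    # Impersonation overrides continuation, since the two modes target
    # opposite sides of the dialogue.
    body = {'continue_': True, 'impersonate': True}

    continue_ = body['continue_']
    impersonate = body['impersonate']
    if impersonate:
        continue_ = False   # impersonate wins; the bot reply is not continued

    assert (continue_, impersonate) == (False, True)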
@@ -294,6 +298,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p

     def chat_streaming_chunk(content):
         # begin streaming
+        role = 'user' if impersonate else 'assistant'
         chunk = {
             "id": cmpl_id,
             "object": object_type,
@@ -302,7 +307,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             resp_list: [{
                 "index": 0,
                 "finish_reason": None,
-                "delta": {'role': 'assistant', 'content': content},
+                "delta": {'role': role, 'content': content},
             }],
         }

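With the role wired through chat_streaming_chunk, streamed deltas are attributed to the user side of the conversation when impersonating. An illustrative chunk, not taken from the commit: the id and object values are placeholders, and resp_list resolving to "choices" on the non-legacy endpoint is an assumption about the surrounding code.

    # Illustrative streamed chunk with impersonate=True; the shape mirrors
    # the chunk dict assembled above, the values are made up.
    chunk = {
        "id": "chatcmpl-123",
        "object": "chat.completions.chunk",
        "choices": [{
            "index": 0,
            "finish_reason": None,
            "delta": {"role": "user", "content": " to ask"},
        }],
    }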
@@ -314,9 +319,12 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         return chunk

     # generate reply #######################################
-    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
+    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_, impersonate=impersonate)
     if prompt_only:
-        yield {'prompt': prompt}
+        if impersonate:
+            yield {'prompt': prompt + user_input}
+        else:
+            yield {'prompt': prompt}
         return

     debug_msg({'prompt': prompt, 'generate_params': generate_params})
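The prompt-only branch appends the pending user text itself because generate_chat_prompt(..., impersonate=True) renders the template only up to the opening of the user's turn. Roughly, with invented template markers:

    # Invented illustration of the impersonation prompt_only branch.
    prompt = "### Assistant: Hi! How can I help you today?\n### User: "
    user_input = "I was wondering whether"
    print(prompt + user_input)   # the string yielded back to the caller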
@@ -324,14 +332,19 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     if stream:
         yield chat_streaming_chunk('')

-    generator = generate_chat_reply(
-        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
+    if impersonate:
+        stopping_strings = get_stopping_strings(generate_params)
+        generator = generate_reply(prompt + user_input, generate_params, stopping_strings=stopping_strings, is_chat=True)
+
+    else:
+        generator = generate_chat_reply(
+            user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)

     answer = ''
     seen_content = ''

     for a in generator:
-        answer = a['internal'][-1][1]
+        answer = a if impersonate else a['internal'][-1][1]
         if stream:
             len_seen = len(seen_content)
             new_content = answer[len_seen:]
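The conditional unwrap in the loop reflects the two generators' different yields: generate_chat_reply emits the whole chat history each step, with the newest reply at ['internal'][-1][1], while the raw generate_reply path used for impersonation emits the accumulated reply string directly. A sketch assuming those shapes:

    # Assumed shapes, inferred from the unwrapping above.
    history_step = {'internal': [['Hello', 'Hi! How can I']]}   # generate_chat_reply yield
    string_step = 'Hi! How can I'                               # generate_reply yield

    impersonate = True
    a = string_step if impersonate else history_step
    answer = a if impersonate else a['internal'][-1][1]
    assert answer == 'Hi! How can I'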
extensions/openai/typing.py

@@ -114,6 +114,8 @@ class ChatCompletionRequestParams(BaseModel):

     continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")

+    impersonate: bool = Field(default=False, description="Impersonate the user in the chat. Makes the model continue generating the last user message.")
+

 class ChatCompletionRequest(GenerationOptions, ChatCompletionRequestParams):
     pass
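Taken together, a client opts in with a single boolean. A sketch against a local deployment (URL and port are assumptions about a default setup, not part of this commit):

    import requests

    # Sketch: ask the model to finish the user's next message.
    url = "http://127.0.0.1:5000/v1/chat/completions"
    body = {
        "messages": [
            {"role": "assistant", "content": "Hi! How can I help you today?"},
            {"role": "user", "content": "I was wondering whether"},  # partial turn
        ],
        "impersonate": True,
        "max_tokens": 60,
    }

    resp = requests.post(url, json=body, timeout=120)
    print(resp.json()["choices"][0]["message"]["content"])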