Implement echo/suffix parameters

oobabooga 2023-11-07 08:43:45 -08:00
parent cee099f131
commit 3d59346871
2 changed files with 8 additions and 6 deletions

View File

@@ -349,8 +349,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
     generate_params['stream'] = stream
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
-    # generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
-    generate_params['echo'] = body.get('echo', generate_params['echo'])
+    suffix = body['suffix'] if body['suffix'] else ''
+    echo = body['echo']
 
     if not stream:
         prompt_arg = body[prompt_str]
@@ -373,6 +373,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             except KeyError:
                 prompt = decode(prompt)[0]
 
+            prefix = prompt if echo else ''
             token_count = len(encode(prompt)[0])
             total_prompt_token_count += token_count
@@ -393,7 +394,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             respi = {
                 "index": idx,
                 "finish_reason": stop_reason,
-                "text": answer,
+                "text": prefix + answer + suffix,
                 "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
             }
@@ -425,6 +426,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
         else:
             raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
 
+        prefix = prompt if echo else ''
         token_count = len(encode(prompt)[0])
 
         def text_streaming_chunk(content):
@@ -444,7 +446,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             return chunk
 
-        yield text_streaming_chunk('')
+        yield text_streaming_chunk(prefix)
 
         # generate reply #######################################
         debug_msg({'prompt': prompt, 'generate_params': generate_params})
@@ -472,7 +474,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
         if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
            stop_reason = "length"
 
-        chunk = text_streaming_chunk('')
+        chunk = text_streaming_chunk(suffix)
         chunk[resp_list][0]["finish_reason"] = stop_reason
         chunk["usage"] = {
             "prompt_tokens": token_count,

View File

@@ -57,7 +57,7 @@ class CompletionRequestParams(BaseModel):
     suffix: str | None = None
     temperature: float | None = 1
     top_p: float | None = 1
-    user: str | None = None
+    user: str | None = Field(default=None, description="Unused parameter.")
 
 class CompletionRequest(GenerationOptions, CompletionRequestParams):
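
The typing change swaps a bare default for a pydantic Field so the unused user parameter carries a description. A minimal sketch of that pattern, assuming pydantic is available and using an illustrative model name:

# Minimal sketch of the Field(default=..., description=...) pattern used above.
# The model name is illustrative; the description ends up in the model's JSON schema.
from pydantic import BaseModel, Field

class ExampleParams(BaseModel):
    suffix: str | None = None
    user: str | None = Field(default=None, description="Unused parameter.")

print(ExampleParams().user)  # None; the field stays optional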