From 3d593468719838401bb0268b00e0bd23cf15d97c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 7 Nov 2023 08:43:45 -0800
Subject: [PATCH] Implement echo/suffix parameters

---
 extensions/openai/completions.py | 12 +++++++-----
 extensions/openai/typing.py      |  2 +-
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index f01282f2..1c0159e8 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -349,8 +349,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
     generate_params['stream'] = stream
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
-    # generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
-    generate_params['echo'] = body.get('echo', generate_params['echo'])
+    suffix = body['suffix'] if body['suffix'] else ''
+    echo = body['echo']
 
     if not stream:
         prompt_arg = body[prompt_str]
@@ -373,6 +373,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
                 except KeyError:
                     prompt = decode(prompt)[0]
 
+            prefix = prompt if echo else ''
             token_count = len(encode(prompt)[0])
             total_prompt_token_count += token_count
 
@@ -393,7 +394,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             respi = {
                 "index": idx,
                 "finish_reason": stop_reason,
-                "text": answer,
+                "text": prefix + answer + suffix,
                 "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
             }
 
@@ -425,6 +426,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             else:
                 raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
 
+        prefix = prompt if echo else ''
         token_count = len(encode(prompt)[0])
 
         def text_streaming_chunk(content):
@@ -444,7 +446,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
 
             return chunk
 
-        yield text_streaming_chunk('')
+        yield text_streaming_chunk(prefix)
 
         # generate reply #######################################
         debug_msg({'prompt': prompt, 'generate_params': generate_params})
@@ -472,7 +474,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
         if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
             stop_reason = "length"
 
-        chunk = text_streaming_chunk('')
+        chunk = text_streaming_chunk(suffix)
         chunk[resp_list][0]["finish_reason"] = stop_reason
         chunk["usage"] = {
             "prompt_tokens": token_count,
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index c9a3b30a..4d49803e 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -57,7 +57,7 @@ class CompletionRequestParams(BaseModel):
     suffix: str | None = None
     temperature: float | None = 1
     top_p: float | None = 1
-    user: str | None = None
+    user: str | None = Field(default=None, description="Unused parameter.")
 
 
 class CompletionRequest(GenerationOptions, CompletionRequestParams):
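
Usage note (not part of the patch): a minimal client-side sketch of how the new parameters behave once this change is applied. The server URL/port and the use of the requests library are illustrative assumptions; only the echo/suffix behavior comes from the patch itself.

# Illustrative only; adjust the base URL/port to your own server configuration.
import requests

payload = {
    "prompt": "def fibonacci(n):",
    "max_tokens": 64,
    "echo": True,                 # with this patch, the prompt is prepended to the returned text
    "suffix": "\n# end of file",  # with this patch, appended to the returned text
}

response = requests.post("http://127.0.0.1:5000/v1/completions", json=payload)
text = response.json()["choices"][0]["text"]

# With echo=True the text starts with the original prompt; with a suffix set,
# it ends with that suffix (sent as the final chunk in streaming mode).
print(text)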