extensions/openai: +Array input (batched), +Fixes (#3309)

matatonic 2023-08-01 21:26:00 -04:00 committed by GitHub
parent 40038fdb82
commit 9ae0eab989
3 changed files with 70 additions and 64 deletions


@@ -174,7 +174,7 @@ print(text)
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, doesn't support array input, variable quality based on the model |
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
| /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint |
| /v1/engines/*/generate | openai engines.generate | Legacy endpoint |
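With this change the completions endpoint accepts array (batched) prompt input, which is why the "doesn't support array input" caveat is dropped from the table above. A minimal client-side sketch using the openai python v0.x client; the base URL, port, and model name are assumptions, and the placeholder API key is typically ignored by the local extension:

```python
# Hypothetical batched request against the local OpenAI-compatible API (openai v0.x client).
import openai

openai.api_key = "sk-dummy"                    # placeholder; not validated locally
openai.api_base = "http://127.0.0.1:5001/v1"   # assumed address/port of the extension

resp = openai.Completion.create(
    model="text-davinci-003",                   # informational for the local backend
    prompt=["Say hello.", "Count to three."],   # array input: one choice per prompt
    max_tokens=16,
)
for choice in resp["choices"]:
    print(choice["index"], choice["text"])
```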
@@ -204,6 +204,7 @@ Some hacky mappings:
| 1.0 | typical_p | hardcoded to 1.0 |
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
| messages.name | - | not supported yet |
| suffix | - | not supported yet |
| user | - | not supported yet |
| functions/function_call | - | function calls are not supported yet |


@@ -48,7 +48,7 @@ class LogprobProcessor(LogitsProcessor):
top_tokens = [ decode(tok) for tok in top_indices[0] ]
top_probs = [ float(x) for x in top_values[0] ]
self.token_alternatives = dict(zip(top_tokens, top_probs))
debug_msg(f"{self.__class__.__name__}(logprobs+1={self.logprobs+1}, token_alternatives={self.token_alternatives})")
debug_msg(repr(self))
return logits
def __repr__(self):
@@ -63,7 +63,8 @@ def convert_logprobs_to_tiktoken(model, logprobs):
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
# except KeyError:
# # assume native tokens if we can't find the tokenizer
return logprobs
# return logprobs
return logprobs
def marshal_common_params(body):
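The tiktoken conversion itself stays commented out here, so `convert_logprobs_to_tiktoken` still passes native logprobs through unchanged. If it were re-enabled, it would look roughly like the commented lines above; a standalone sketch (assumption, not code from this commit):

```python
# Sketch of the disabled conversion: remap each alternative token to the text of its
# first tiktoken piece for the requested model, falling back to native tokens.
import tiktoken

def convert_logprobs_to_tiktoken_sketch(model, logprobs):
    try:
        encoder = tiktoken.encoding_for_model(model)
        return {encoder.decode([encoder.encode(token)[0]]): prob
                for token, prob in logprobs.items()}
    except KeyError:
        # unknown model: assume the tokens are already native and pass them through
        return logprobs
```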
@@ -271,16 +272,16 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
req_params['max_new_tokens'] = req_params['truncation_length']
# format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
# set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
stopping_strings = req_params.pop('stopping_strings', [])
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
logprob_proc = req_params.pop('logprob_proc', None)
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
@@ -347,7 +348,7 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
req_params['max_new_tokens'] = req_params['truncation_length']
# format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
# set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
@@ -441,16 +442,9 @@ def completions(body: dict, is_legacy: bool = False):
if not prompt_str in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt = body[prompt_str]
if isinstance(prompt, list):
if prompt and isinstance(prompt[0], int):
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
else:
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
prompt_arg = body[prompt_str]
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
prompt_arg = [prompt_arg]
# common params
req_params = marshal_common_params(body)
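For reference, the normalization above reduces the accepted prompt shapes to one batched form before the per-prompt loop runs. An illustration only (not part of the diff); the helper name is hypothetical:

```python
# A bare string or a single token-id list becomes a one-element batch;
# a list of strings (or of token lists) is already treated as a batch.
def normalize_prompt(prompt_arg):
    if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and prompt_arg and isinstance(prompt_arg[0], int)):
        return [prompt_arg]
    return prompt_arg

assert normalize_prompt("hi") == ["hi"]
assert normalize_prompt([15496, 995]) == [[15496, 995]]
assert normalize_prompt(["a", "b"]) == ["a", "b"]
```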
@@ -460,59 +454,75 @@ def completions(body: dict, is_legacy: bool = False):
req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
token_count = len(encode(prompt)[0])
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
stopping_strings = req_params.pop('stopping_strings', [])
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
resp_list_data = []
total_completion_token_count = 0
total_prompt_token_count = 0
answer = ''
for idx, prompt in enumerate(prompt_arg, start=0):
if isinstance(prompt[0], int):
# token lists
if requested_model == shared.model_name:
prompt = decode(prompt)[0]
else:
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
for a in generator:
answer = a
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
for a in generator:
answer = a
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
respi = {
"index": idx,
"finish_reason": stop_reason,
"text": answer,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}
resp_list_data.extend([respi])
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name, # TODO: add Lora info?
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"text": answer,
}],
resp_list: resp_list_data,
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
"prompt_tokens": total_prompt_token_count,
"completion_tokens": total_completion_token_count,
"total_tokens": total_prompt_token_count + total_completion_token_count
}
}
if logprob_proc and logprob_proc.token_alternatives:
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
else:
resp[resp_list][0]["logprobs"] = None
return resp
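The loop above also handles prompts sent as token ids: each entry is decoded with tiktoken when the requested model is not the locally loaded one, and usage counts are summed over the whole batch. A hedged client-side check (openai python v0.x); the address, port, and model name are assumptions:

```python
# Hypothetical batched token-id request; the response carries one choice per prompt
# and aggregated usage for the whole batch.
import openai
import tiktoken

openai.api_key = "sk-dummy"
openai.api_base = "http://127.0.0.1:5001/v1"   # assumed local endpoint

enc = tiktoken.encoding_for_model("text-davinci-003")
prompts = [enc.encode("The capital of France is"), enc.encode("2 + 2 =")]

resp = openai.Completion.create(model="text-davinci-003", prompt=prompts, max_tokens=8)
for choice in sorted(resp["choices"], key=lambda c: c["index"]):
    print(choice["index"], repr(choice["text"]), choice["finish_reason"])
print(resp["usage"])   # prompt/completion/total tokens summed over the batch
```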
@@ -550,6 +560,10 @@ def stream_completions(body: dict, is_legacy: bool = False):
req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
token_count = len(encode(prompt)[0])
@@ -558,9 +572,6 @@ def stream_completions(body: dict, is_legacy: bool = False):
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
def text_streaming_chunk(content):
# begin streaming
chunk = {
@@ -572,13 +583,9 @@ def stream_completions(body: dict, is_legacy: bool = False):
"index": 0,
"finish_reason": None,
"text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}],
}
if logprob_proc:
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
else:
chunk[resp_list][0]["logprobs"] = None
return chunk
@@ -586,8 +593,6 @@ def stream_completions(body: dict, is_legacy: bool = False):
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
logprob_proc = req_params.pop('logprob_proc', None)
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
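Streaming completions keep single-prompt behavior; the change here only moves the `stopping_strings`/`logprob_proc` pops earlier and attaches per-chunk logprobs directly. A hedged consumption sketch (openai python v0.x, address and model name assumed):

```python
# Hypothetical streaming request against the local endpoint.
import openai

openai.api_key = "sk-dummy"
openai.api_base = "http://127.0.0.1:5001/v1"   # assumed local endpoint

for chunk in openai.Completion.create(model="text-davinci-003",
                                      prompt="Once upon a time",
                                      max_tokens=32,
                                      stream=True):
    print(chunk["choices"][0]["text"], end="", flush=True)
print()
```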


@@ -120,7 +120,7 @@ class Handler(BaseHTTPRequestHandler):
resp = OAImodels.list_models(is_legacy)
else:
model_name = self.path[len('/v1/models/'):]
resp = OAImodels.model_info()
resp = OAImodels.model_info(model_name)
self.return_json(resp)
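This last hunk fixes the per-model lookup: the id parsed from the path is now actually passed to `OAImodels.model_info()`. A quick client-side check (openai python v0.x, `openai.Model.retrieve` maps to `/v1/models/{id}`); the address is an assumption:

```python
# Hypothetical check that /v1/models/{id} now echoes info for the requested id.
import openai

openai.api_key = "sk-dummy"
openai.api_base = "http://127.0.0.1:5001/v1"   # assumed local endpoint

print(openai.Model.retrieve("gpt-3.5-turbo"))
```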