extensions/openai: +Array input (batched) , +Fixes (#3309)

This commit is contained in:
matatonic 2023-08-01 21:26:00 -04:00 committed by GitHub
parent 40038fdb82
commit 9ae0eab989
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 64 deletions

View File

@ -174,7 +174,7 @@ print(text)
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options | | /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for | | /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models | | /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, doesn't support array input, variable quality based on the model | | /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) | | /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
| /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint | | /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint |
| /v1/engines/*/generate | openai engines.generate | Legacy endpoint | | /v1/engines/*/generate | openai engines.generate | Legacy endpoint |
@ -204,6 +204,7 @@ Some hacky mappings:
| 1.0 | typical_p | hardcoded to 1.0 | | 1.0 | typical_p | hardcoded to 1.0 |
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' | | logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
| messages.name | - | not supported yet | | messages.name | - | not supported yet |
| suffix | - | not supported yet |
| user | - | not supported yet | | user | - | not supported yet |
| functions/function_call | - | function calls are not supported yet | | functions/function_call | - | function calls are not supported yet |

View File

@ -48,7 +48,7 @@ class LogprobProcessor(LogitsProcessor):
top_tokens = [ decode(tok) for tok in top_indices[0] ] top_tokens = [ decode(tok) for tok in top_indices[0] ]
top_probs = [ float(x) for x in top_values[0] ] top_probs = [ float(x) for x in top_values[0] ]
self.token_alternatives = dict(zip(top_tokens, top_probs)) self.token_alternatives = dict(zip(top_tokens, top_probs))
debug_msg(f"{self.__class__.__name__}(logprobs+1={self.logprobs+1}, token_alternatives={self.token_alternatives})") debug_msg(repr(self))
return logits return logits
def __repr__(self): def __repr__(self):
@ -63,7 +63,8 @@ def convert_logprobs_to_tiktoken(model, logprobs):
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()]) # return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
# except KeyError: # except KeyError:
# # assume native tokens if we can't find the tokenizer # # assume native tokens if we can't find the tokenizer
return logprobs # return logprobs
return logprobs
def marshal_common_params(body): def marshal_common_params(body):
@ -271,16 +272,16 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
req_params['max_new_tokens'] = req_params['truncation_length'] req_params['max_new_tokens'] = req_params['truncation_length']
# format the prompt from messages # format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
# set real max, avoid deeper errors # set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']: if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
stopping_strings = req_params.pop('stopping_strings', [])
# generate reply ####################################### # generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params}) debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
logprob_proc = req_params.pop('logprob_proc', None)
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = '' answer = ''
@ -347,7 +348,7 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
req_params['max_new_tokens'] = req_params['truncation_length'] req_params['max_new_tokens'] = req_params['truncation_length']
# format the prompt from messages # format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
# set real max, avoid deeper errors # set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']: if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
@ -441,16 +442,9 @@ def completions(body: dict, is_legacy: bool = False):
if not prompt_str in body: if not prompt_str in body:
raise InvalidRequestError("Missing required input", param=prompt_str) raise InvalidRequestError("Missing required input", param=prompt_str)
prompt = body[prompt_str] prompt_arg = body[prompt_str]
if isinstance(prompt, list): if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
if prompt and isinstance(prompt[0], int): prompt_arg = [prompt_arg]
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
else:
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
# common params # common params
req_params = marshal_common_params(body) req_params = marshal_common_params(body)
@ -460,59 +454,75 @@ def completions(body: dict, is_legacy: bool = False):
req_params['max_new_tokens'] = max_tokens req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model') requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None) logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
token_count = len(encode(prompt)[0]) #req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
req_params['echo'] = default(body, 'echo', req_params['echo']) req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k']) req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
# generate reply ####################################### resp_list_data = []
debug_msg({'prompt': prompt, 'req_params': req_params}) total_completion_token_count = 0
stopping_strings = req_params.pop('stopping_strings', []) total_prompt_token_count = 0
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = '' for idx, prompt in enumerate(prompt_arg, start=0):
if isinstance(prompt[0], int):
# token lists
if requested_model == shared.model_name:
prompt = decode(prompt)[0]
else:
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
for a in generator: token_count = len(encode(prompt)[0])
answer = a total_prompt_token_count += token_count
# strip extra leading space off new generated content if token_count + max_tokens > req_params['truncation_length']:
if answer and answer[0] == ' ': err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
answer = answer[1:] # print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
completion_token_count = len(encode(answer)[0]) # generate reply #######################################
stop_reason = "stop" debug_msg({'prompt': prompt, 'req_params': req_params})
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens: generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
stop_reason = "length" answer = ''
for a in generator:
answer = a
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
respi = {
"index": idx,
"finish_reason": stop_reason,
"text": answer,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}
resp_list_data.extend([respi])
resp = { resp = {
"id": cmpl_id, "id": cmpl_id,
"object": object_type, "object": object_type,
"created": created_time, "created": created_time,
"model": shared.model_name, # TODO: add Lora info? "model": shared.model_name, # TODO: add Lora info?
resp_list: [{ resp_list: resp_list_data,
"index": 0,
"finish_reason": stop_reason,
"text": answer,
}],
"usage": { "usage": {
"prompt_tokens": token_count, "prompt_tokens": total_prompt_token_count,
"completion_tokens": completion_token_count, "completion_tokens": total_completion_token_count,
"total_tokens": token_count + completion_token_count "total_tokens": total_prompt_token_count + total_completion_token_count
} }
} }
if logprob_proc and logprob_proc.token_alternatives:
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
else:
resp[resp_list][0]["logprobs"] = None
return resp return resp
@ -550,6 +560,10 @@ def stream_completions(body: dict, is_legacy: bool = False):
req_params['max_new_tokens'] = max_tokens req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model') requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None) logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
token_count = len(encode(prompt)[0]) token_count = len(encode(prompt)[0])
@ -558,9 +572,6 @@ def stream_completions(body: dict, is_legacy: bool = False):
# print(f"Warning: ${err_msg}") # print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str) raise InvalidRequestError(message=err_msg, param=max_tokens_str)
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
def text_streaming_chunk(content): def text_streaming_chunk(content):
# begin streaming # begin streaming
chunk = { chunk = {
@ -572,13 +583,9 @@ def stream_completions(body: dict, is_legacy: bool = False):
"index": 0, "index": 0,
"finish_reason": None, "finish_reason": None,
"text": content, "text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}], }],
} }
if logprob_proc:
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
else:
chunk[resp_list][0]["logprobs"] = None
return chunk return chunk
@ -586,8 +593,6 @@ def stream_completions(body: dict, is_legacy: bool = False):
# generate reply ####################################### # generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params}) debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
logprob_proc = req_params.pop('logprob_proc', None)
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = '' answer = ''

View File

@ -120,7 +120,7 @@ class Handler(BaseHTTPRequestHandler):
resp = OAImodels.list_models(is_legacy) resp = OAImodels.list_models(is_legacy)
else: else:
model_name = self.path[len('/v1/models/'):] model_name = self.path[len('/v1/models/'):]
resp = OAImodels.model_info() resp = OAImodels.model_info(model_name)
self.return_json(resp) self.return_json(resp)