mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
extensions/openai: +Array input (batched) , +Fixes (#3309)
This commit is contained in:
parent
40038fdb82
commit
9ae0eab989
@ -174,7 +174,7 @@ print(text)
|
||||
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
|
||||
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
|
||||
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
|
||||
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, doesn't support array input, variable quality based on the model |
|
||||
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
|
||||
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
|
||||
| /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint |
|
||||
| /v1/engines/*/generate | openai engines.generate | Legacy endpoint |
|
||||
@ -204,6 +204,7 @@ Some hacky mappings:
|
||||
| 1.0 | typical_p | hardcoded to 1.0 |
|
||||
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
|
||||
| messages.name | - | not supported yet |
|
||||
| suffix | - | not supported yet |
|
||||
| user | - | not supported yet |
|
||||
| functions/function_call | - | function calls are not supported yet |
|
||||
|
||||
|
@ -48,7 +48,7 @@ class LogprobProcessor(LogitsProcessor):
|
||||
top_tokens = [ decode(tok) for tok in top_indices[0] ]
|
||||
top_probs = [ float(x) for x in top_values[0] ]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||
debug_msg(f"{self.__class__.__name__}(logprobs+1={self.logprobs+1}, token_alternatives={self.token_alternatives})")
|
||||
debug_msg(repr(self))
|
||||
return logits
|
||||
|
||||
def __repr__(self):
|
||||
@ -63,6 +63,7 @@ def convert_logprobs_to_tiktoken(model, logprobs):
|
||||
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
||||
# except KeyError:
|
||||
# # assume native tokens if we can't find the tokenizer
|
||||
# return logprobs
|
||||
return logprobs
|
||||
|
||||
|
||||
@ -271,16 +272,16 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
@ -347,7 +348,7 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
@ -441,16 +442,9 @@ def completions(body: dict, is_legacy: bool = False):
|
||||
if not prompt_str in body:
|
||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||
|
||||
prompt = body[prompt_str]
|
||||
if isinstance(prompt, list):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||
prompt_arg = body[prompt_str]
|
||||
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
|
||||
prompt_arg = [prompt_arg]
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
@ -460,22 +454,38 @@ def completions(body: dict, is_legacy: bool = False):
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
|
||||
resp_list_data = []
|
||||
total_completion_token_count = 0
|
||||
total_prompt_token_count = 0
|
||||
|
||||
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||
if isinstance(prompt[0], int):
|
||||
# token lists
|
||||
if requested_model == shared.model_name:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
total_prompt_token_count += token_count
|
||||
|
||||
if token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||
# print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
|
||||
for a in generator:
|
||||
@ -486,33 +496,33 @@ def completions(body: dict, is_legacy: bool = False):
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
total_completion_token_count += completion_token_count
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
respi = {
|
||||
"index": idx,
|
||||
"finish_reason": stop_reason,
|
||||
"text": answer,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
}
|
||||
|
||||
resp_list_data.extend([respi])
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name, # TODO: add Lora info?
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": stop_reason,
|
||||
"text": answer,
|
||||
}],
|
||||
resp_list: resp_list_data,
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
"prompt_tokens": total_prompt_token_count,
|
||||
"completion_tokens": total_completion_token_count,
|
||||
"total_tokens": total_prompt_token_count + total_completion_token_count
|
||||
}
|
||||
}
|
||||
|
||||
if logprob_proc and logprob_proc.token_alternatives:
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
else:
|
||||
resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@ -550,6 +560,10 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
@ -558,9 +572,6 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
||||
# print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
@ -572,13 +583,9 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"text": content,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
}],
|
||||
}
|
||||
if logprob_proc:
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
else:
|
||||
chunk[resp_list][0]["logprobs"] = None
|
||||
|
||||
return chunk
|
||||
|
||||
@ -586,8 +593,6 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
|
@ -120,7 +120,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
resp = OAImodels.list_models(is_legacy)
|
||||
else:
|
||||
model_name = self.path[len('/v1/models/'):]
|
||||
resp = OAImodels.model_info()
|
||||
resp = OAImodels.model_info(model_name)
|
||||
|
||||
self.return_json(resp)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user