extensions/openai: +Array input (batched), +Fixes (#3309)
parent 40038fdb82
commit 9ae0eab989
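
For context, a minimal sketch of the array ("batched") input this commit adds to /v1/completions, using the legacy openai-python 0.x client that the extension emulates. The base URL, API key, and model name below are assumptions for illustration, not part of the commit:

# Minimal sketch: batched prompts against the extension's /v1/completions.
# Assumptions: webui running with the openai extension reachable at the address
# below; legacy openai-python 0.x client; model name is illustrative only.
import openai

openai.api_key = "dummy"                      # placeholder; the extension typically ignores the key
openai.api_base = "http://127.0.0.1:5001/v1"  # assumed address of the extension

resp = openai.Completion.create(
    model="text-davinci-003",
    prompt=["Say hello.", "Count to three."],  # a list of prompts is now accepted
    max_tokens=32,
)
for choice in resp["choices"]:
    print(choice["index"], choice["finish_reason"], choice["text"])
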
@@ -174,7 +174,7 @@ print(text)
 | /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
 | /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
 | /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
-| /v1/text_completion | openai.Completion.create() | Legacy endpoint, doesn't support array input, variable quality based on the model |
+| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
 | /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
 | /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint |
 | /v1/engines/*/generate | openai engines.generate | Legacy endpoint |
@@ -204,6 +204,7 @@ Some hacky mappings:
 | 1.0 | typical_p | hardcoded to 1.0 |
 | logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
 | messages.name | - | not supported yet |
+| suffix | - | not supported yet |
 | user | - | not supported yet |
 | functions/function_call | - | function calls are not supported yet |
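
Since the table above flags logprobs and logit_bias as experimental, here is a hedged sketch of how a client might exercise them; the endpoint address, model name, and token id are assumptions for illustration only:

# Hedged sketch: requesting top logprobs and applying a logit bias (experimental per
# the table above). Token id "1234" is an arbitrary placeholder, not a real lookup.
import openai

openai.api_key = "dummy"
openai.api_base = "http://127.0.0.1:5001/v1"  # assumed extension address

resp = openai.Completion.create(
    model="text-davinci-003",         # non-openai model names fall back to native (llama) token ids
    prompt="The capital of France is",
    max_tokens=1,
    logprobs=5,                       # ask for the top-5 alternatives per position
    logit_bias={"1234": -100},        # token id -> bias; placeholder id
)
print(resp["choices"][0]["logprobs"]["top_logprobs"])
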
@@ -48,7 +48,7 @@ class LogprobProcessor(LogitsProcessor):
         top_tokens = [ decode(tok) for tok in top_indices[0] ]
         top_probs = [ float(x) for x in top_values[0] ]
         self.token_alternatives = dict(zip(top_tokens, top_probs))
-        debug_msg(f"{self.__class__.__name__}(logprobs+1={self.logprobs+1}, token_alternatives={self.token_alternatives})")
+        debug_msg(repr(self))
         return logits

     def __repr__(self):
@@ -63,7 +63,8 @@ def convert_logprobs_to_tiktoken(model, logprobs):
     #     return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
     # except KeyError:
     #     # assume native tokens if we can't find the tokenizer
-    return logprobs
+    #     return logprobs
+    return logprobs


 def marshal_common_params(body):
@@ -271,16 +272,16 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
     req_params['max_new_tokens'] = req_params['truncation_length']

     # format the prompt from messages
-    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
+    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)  # updates req_params['stopping_strings']

     # set real max, avoid deeper errors
     if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
         req_params['max_new_tokens'] = req_params['truncation_length'] - token_count

+    stopping_strings = req_params.pop('stopping_strings', [])
+
     # generate reply #######################################
     debug_msg({'prompt': prompt, 'req_params': req_params})
-    stopping_strings = req_params.pop('stopping_strings', [])
-    logprob_proc = req_params.pop('logprob_proc', None)
     generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

     answer = ''
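
The new comment on messages_to_prompt documents an ordering constraint: the helper updates req_params['stopping_strings'], so the pop has to stay after that call. A self-contained illustration (the helper below is a simplified stand-in, not the real function):

# Simplified stand-in for messages_to_prompt(): it mutates req_params in place.
def messages_to_prompt(body, req_params, max_tokens):
    req_params['stopping_strings'].extend(['\nuser:', '\nassistant:'])  # illustrative values
    return "formatted prompt", 7

req_params = {'stopping_strings': []}

too_early = list(req_params['stopping_strings'])            # popping here would miss the update: []

prompt, token_count = messages_to_prompt({}, req_params, 32)
stopping_strings = req_params.pop('stopping_strings', [])   # now includes the chat markers
print(too_early, stopping_strings)                          # [] ['\nuser:', '\nassistant:']
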
@@ -347,7 +348,7 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
     req_params['max_new_tokens'] = req_params['truncation_length']

     # format the prompt from messages
-    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
+    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)  # updates req_params['stopping_strings']

     # set real max, avoid deeper errors
     if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
@@ -441,16 +442,9 @@ def completions(body: dict, is_legacy: bool = False):
     if not prompt_str in body:
         raise InvalidRequestError("Missing required input", param=prompt_str)

-    prompt = body[prompt_str]
-    if isinstance(prompt, list):
-        if prompt and isinstance(prompt[0], int):
-            try:
-                encoder = tiktoken.encoding_for_model(requested_model)
-                prompt = encoder.decode(prompt)
-            except KeyError:
-                prompt = decode(prompt)[0]
-        else:
-            raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
+    prompt_arg = body[prompt_str]
+    if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
+        prompt_arg = [prompt_arg]

     # common params
     req_params = marshal_common_params(body)
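
The rewritten check above normalizes prompt_arg so that a single prompt (a plain string or one token list) becomes a one-element batch, while an already-batched list passes through untouched. A standalone sketch of that logic (the helper name is mine, for illustration):

# Standalone sketch of the normalization introduced above (hypothetical helper name).
def normalize_prompt_arg(prompt_arg):
    if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
        prompt_arg = [prompt_arg]
    return prompt_arg

assert normalize_prompt_arg("one prompt") == ["one prompt"]      # single string -> batch of one
assert normalize_prompt_arg([1, 2, 3]) == [[1, 2, 3]]            # single token list -> batch of one
assert normalize_prompt_arg(["a", "b"]) == ["a", "b"]            # list of strings: already a batch
assert normalize_prompt_arg([[1, 2], [3]]) == [[1, 2], [3]]      # list of token lists: already a batch
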
@@ -460,59 +454,75 @@ def completions(body: dict, is_legacy: bool = False):
     req_params['max_new_tokens'] = max_tokens
     requested_model = req_params.pop('requested_model')
     logprob_proc = req_params.pop('logprob_proc', None)
-    token_count = len(encode(prompt)[0])
+    stopping_strings = req_params.pop('stopping_strings', [])
+    #req_params['suffix'] = default(body, 'suffix', req_params['suffix'])

-    if token_count + max_tokens > req_params['truncation_length']:
-        err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
-        # print(f"Warning: ${err_msg}")
-        raise InvalidRequestError(message=err_msg, param=max_tokens_str)
-
     req_params['echo'] = default(body, 'echo', req_params['echo'])
     req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

-    # generate reply #######################################
-    debug_msg({'prompt': prompt, 'req_params': req_params})
-    stopping_strings = req_params.pop('stopping_strings', [])
-    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
+    resp_list_data = []
+    total_completion_token_count = 0
+    total_prompt_token_count = 0

-    answer = ''
+    for idx, prompt in enumerate(prompt_arg, start=0):
+        if isinstance(prompt[0], int):
+            # token lists
+            if requested_model == shared.model_name:
+                prompt = decode(prompt)[0]
+            else:
+                try:
+                    encoder = tiktoken.encoding_for_model(requested_model)
+                    prompt = encoder.decode(prompt)
+                except KeyError:
+                    prompt = decode(prompt)[0]

-    for a in generator:
-        answer = a
+        token_count = len(encode(prompt)[0])
+        total_prompt_token_count += token_count

-    # strip extra leading space off new generated content
-    if answer and answer[0] == ' ':
-        answer = answer[1:]
+        if token_count + max_tokens > req_params['truncation_length']:
+            err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
+            # print(f"Warning: ${err_msg}")
+            raise InvalidRequestError(message=err_msg, param=max_tokens_str)

-    completion_token_count = len(encode(answer)[0])
-    stop_reason = "stop"
-    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
-        stop_reason = "length"
+        # generate reply #######################################
+        debug_msg({'prompt': prompt, 'req_params': req_params})
+        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
+        answer = ''

+        for a in generator:
+            answer = a
+
+        # strip extra leading space off new generated content
+        if answer and answer[0] == ' ':
+            answer = answer[1:]
+
+        completion_token_count = len(encode(answer)[0])
+        total_completion_token_count += completion_token_count
+        stop_reason = "stop"
+        if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
+            stop_reason = "length"
+
+        respi = {
+            "index": idx,
+            "finish_reason": stop_reason,
+            "text": answer,
+            "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
+        }
+
+        resp_list_data.extend([respi])
+
     resp = {
         "id": cmpl_id,
         "object": object_type,
         "created": created_time,
         "model": shared.model_name, # TODO: add Lora info?
-        resp_list: [{
-            "index": 0,
-            "finish_reason": stop_reason,
-            "text": answer,
-        }],
+        resp_list: resp_list_data,
         "usage": {
-            "prompt_tokens": token_count,
-            "completion_tokens": completion_token_count,
-            "total_tokens": token_count + completion_token_count
+            "prompt_tokens": total_prompt_token_count,
+            "completion_tokens": total_completion_token_count,
+            "total_tokens": total_prompt_token_count + total_completion_token_count
         }
     }

-    if logprob_proc and logprob_proc.token_alternatives:
-        top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-        resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
-    else:
-        resp[resp_list][0]["logprobs"] = None
-
     return resp

@@ -550,6 +560,10 @@ def stream_completions(body: dict, is_legacy: bool = False):
     req_params['max_new_tokens'] = max_tokens
     requested_model = req_params.pop('requested_model')
     logprob_proc = req_params.pop('logprob_proc', None)
+    stopping_strings = req_params.pop('stopping_strings', [])
+    #req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
+    req_params['echo'] = default(body, 'echo', req_params['echo'])
+    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

     token_count = len(encode(prompt)[0])

@@ -558,9 +572,6 @@ def stream_completions(body: dict, is_legacy: bool = False):
         # print(f"Warning: ${err_msg}")
         raise InvalidRequestError(message=err_msg, param=max_tokens_str)

-    req_params['echo'] = default(body, 'echo', req_params['echo'])
-    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
-
     def text_streaming_chunk(content):
         # begin streaming
         chunk = {
@@ -572,13 +583,9 @@ def stream_completions(body: dict, is_legacy: bool = False):
                 "index": 0,
                 "finish_reason": None,
                 "text": content,
+                "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
             }],
         }
-        if logprob_proc:
-            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
-        else:
-            chunk[resp_list][0]["logprobs"] = None

         return chunk

@@ -586,8 +593,6 @@ def stream_completions(body: dict, is_legacy: bool = False):

     # generate reply #######################################
     debug_msg({'prompt': prompt, 'req_params': req_params})
-    stopping_strings = req_params.pop('stopping_strings', [])
-    logprob_proc = req_params.pop('logprob_proc', None)
     generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

     answer = ''
@@ -120,7 +120,7 @@ class Handler(BaseHTTPRequestHandler):
             resp = OAImodels.list_models(is_legacy)
         else:
            model_name = self.path[len('/v1/models/'):]
-            resp = OAImodels.model_info()
+            resp = OAImodels.model_info(model_name)

         self.return_json(resp)
