mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
extensions/openai: Major openai extension updates & fixes (#3049)
* many openai updates * total reorg & cleanup. * fixups * missing import os for images * +moderations, custom_stopping_strings, more fixes * fix bugs in completion streaming * moderation fix (flagged) * updated moderation categories --------- Co-authored-by: Matthew Ashton <mashton-gitlab@zhero.org>
This commit is contained in:
parent
8db7e857b1
commit
3e7feb699c
@ -218,12 +218,11 @@ but there are some exceptions.
|
||||
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
|
||||
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
||||
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
|
||||
|
||||
## Future plans
|
||||
* better error handling
|
||||
* model changing, esp. something for swapping loras or embedding models
|
||||
* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
||||
* do something about rate limiting or locking requests for completions, most systems will only be able handle a single request at a time before OOM
|
||||
|
||||
## Bugs? Feedback? Comments? Pull requests?
|
||||
|
||||
|
599
extensions/openai/completions.py
Normal file
599
extensions/openai/completions.py
Normal file
@ -0,0 +1,599 @@
|
||||
import time
|
||||
import yaml
|
||||
import tiktoken
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import LogitsProcessor, LogitsProcessorList
|
||||
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, decode, generate_reply
|
||||
|
||||
from extensions.openai.defaults import get_default_req_params, default, clamp
|
||||
from extensions.openai.utils import end_line, debug_msg
|
||||
from extensions.openai.errors import *
|
||||
|
||||
|
||||
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
|
||||
class LogitsBiasProcessor(LogitsProcessor):
|
||||
def __init__(self, logit_bias={}):
|
||||
self.logit_bias = logit_bias
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.logit_bias:
|
||||
keys = list([int(key) for key in self.logit_bias.keys()])
|
||||
values = list([int(val) for val in self.logit_bias.values()])
|
||||
logits[0, keys] += torch.tensor(values).cuda()
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class LogprobProcessor(LogitsProcessor):
|
||||
def __init__(self, logprobs=None):
|
||||
self.logprobs = logprobs
|
||||
self.token_alternatives = {}
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.logprobs is not None: # 0-5
|
||||
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||
# XXX hack. should find the selected token and include the prob of that
|
||||
# ... but we just +1 here instead because we don't know it yet.
|
||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||
top_tokens = [ decode(tok) for tok in top_indices[0] ]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_values[0].tolist()))
|
||||
return logits
|
||||
|
||||
|
||||
def convert_logprobs_to_tiktoken(model, logprobs):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(model)
|
||||
# just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||
return dict([ (encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items() ])
|
||||
except KeyError:
|
||||
# assume native tokens if we can't find the tokenizer
|
||||
return logprobs
|
||||
|
||||
|
||||
def marshal_common_params(body):
|
||||
# Request Parameters
|
||||
# Try to use openai defaults or map them to something with the same intent
|
||||
|
||||
req_params = get_default_req_params()
|
||||
|
||||
# Common request parameters
|
||||
req_params['truncation_length'] = shared.settings['truncation_length']
|
||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
||||
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
||||
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||
|
||||
# OpenAI API Parameters
|
||||
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
|
||||
req_params['requested_model'] = body.get('model', shared.model_name)
|
||||
|
||||
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0/2.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||
n = default(body, 'n', 1)
|
||||
if n != 1:
|
||||
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
|
||||
|
||||
if 'stop' in body: # str or array, max len 4 (ignored)
|
||||
if isinstance(body['stop'], str):
|
||||
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
|
||||
elif isinstance(body['stop'], list):
|
||||
req_params['stopping_strings'] = body['stop']
|
||||
|
||||
# presence_penalty - ignored
|
||||
# frequency_penalty - ignored
|
||||
# user - ignored
|
||||
|
||||
logits_processor = []
|
||||
logit_bias = body.get('logit_bias', None)
|
||||
if logit_bias: # {str: float, ...}
|
||||
# XXX convert tokens from tiktoken based on requested model
|
||||
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
|
||||
new_logit_bias = {}
|
||||
for logit, bias in logit_bias.items():
|
||||
for x in encode(encoder.decode([int(logit)]))[0]:
|
||||
new_logit_bias[str(int(x))] = bias
|
||||
print(logit_bias, '->', new_logit_bias)
|
||||
logit_bias = new_logit_bias
|
||||
except KeyError:
|
||||
pass # assume native tokens if we can't find the tokenizer
|
||||
|
||||
logits_processor = [LogitsBiasProcessor(logit_bias)]
|
||||
|
||||
logprobs = None # coming to chat eventually
|
||||
if 'logprobs' in body:
|
||||
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||
req_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||
logits_processor.extend([req_params['logprob_proc']])
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
if logits_processor: # requires logits_processor support
|
||||
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||
|
||||
return req_params
|
||||
|
||||
|
||||
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
# functions
|
||||
if body.get('functions', []): # chat only
|
||||
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
||||
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
||||
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
||||
|
||||
if not 'messages' in body:
|
||||
raise InvalidRequestError(message="messages is required", param='messages')
|
||||
|
||||
messages = body['messages']
|
||||
|
||||
role_formats = {
|
||||
'user': 'user: {message}\n',
|
||||
'assistant': 'assistant: {message}\n',
|
||||
'system': '{message}',
|
||||
'context': 'You are a helpful assistant. Answer as concisely as possible.',
|
||||
'prompt': 'assistant:',
|
||||
}
|
||||
|
||||
if not 'stopping_strings' in req_params:
|
||||
req_params['stopping_strings'] = []
|
||||
|
||||
# Instruct models can be much better
|
||||
if shared.settings['instruction_template']:
|
||||
try:
|
||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||
|
||||
template = instruct['turn_template']
|
||||
system_message_template = "{message}"
|
||||
system_message_default = instruct['context']
|
||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||
|
||||
role_formats = {
|
||||
'user': user_message_template,
|
||||
'assistant': bot_message_template,
|
||||
'system': system_message_template,
|
||||
'context': system_message_default,
|
||||
'prompt': bot_prompt,
|
||||
}
|
||||
|
||||
if 'Alpaca' in shared.settings['instruction_template']:
|
||||
req_params['stopping_strings'].extend(['\n###'])
|
||||
elif instruct['user']: # WizardLM and some others have no user prompt.
|
||||
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
||||
|
||||
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||
|
||||
except Exception as e:
|
||||
req_params['stopping_strings'].extend(['\nuser:'])
|
||||
|
||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
else:
|
||||
req_params['stopping_strings'].extend(['\nuser:'])
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
system_msgs = []
|
||||
chat_msgs = []
|
||||
|
||||
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
||||
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
|
||||
context_msg = end_line(context_msg)
|
||||
|
||||
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
||||
if 'prompt' in body:
|
||||
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
|
||||
|
||||
for m in messages:
|
||||
role = m['role']
|
||||
content = m['content']
|
||||
# name = m.get('name', None)
|
||||
# function_call = m.get('function_call', None) # user name or function name with output in content
|
||||
msg = role_formats[role].format(message=content)
|
||||
if role == 'system':
|
||||
system_msgs.extend([msg])
|
||||
elif role == 'function':
|
||||
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||
else:
|
||||
chat_msgs.extend([msg])
|
||||
|
||||
system_msg = '\n'.join(system_msgs)
|
||||
system_msg = end_line(system_msg)
|
||||
|
||||
prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
if token_count >= req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
|
||||
raise InvalidRequestError(message=err_msg)
|
||||
|
||||
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
|
||||
print(f"Warning: ${err_msg}")
|
||||
#raise InvalidRequestError(message=err_msg)
|
||||
|
||||
return prompt, token_count
|
||||
|
||||
|
||||
def chat_completions(body: dict, is_legacy: bool=False) -> dict:
|
||||
# Chat Completions
|
||||
object_type = 'chat.completions'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = False
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
|
||||
# chat default max_tokens is 'inf', but also flexible
|
||||
max_tokens = 0
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
if max_tokens_str in body:
|
||||
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
else:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name, # TODO: add Lora info?
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": stop_reason,
|
||||
"message": {"role": "assistant", "content": answer}
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
}
|
||||
if logprob_proc: # not official for chat yet
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
# else:
|
||||
# resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# generator
|
||||
def stream_chat_completions(body: dict, is_legacy: bool=False):
|
||||
|
||||
# Chat Completions
|
||||
stream_object_type = 'chat.completions.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = True
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
|
||||
# chat default max_tokens is 'inf', but also flexible
|
||||
max_tokens = 0
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
if max_tokens_str in body:
|
||||
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
else:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||
|
||||
def chat_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": stream_object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
# So yeah... do both methods? delta and messages.
|
||||
"message": {'role': 'assistant', 'content': content},
|
||||
"delta": {'role': 'assistant', 'content': content},
|
||||
}],
|
||||
}
|
||||
|
||||
if logprob_proc: # not official for chat yet
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
#else:
|
||||
# chunk[resp_list][0]["logprobs"] = None
|
||||
return chunk
|
||||
|
||||
yield chat_streaming_chunk('')
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
seen_content = answer
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
chunk = chat_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
chunk = chat_streaming_chunk('')
|
||||
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||
chunk['usage'] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
yield chunk
|
||||
|
||||
|
||||
def completions(body: dict, is_legacy: bool=False):
|
||||
# Legacy
|
||||
# Text Completions
|
||||
object_type = 'text_completion'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "conv-%d" % (int(time.time()*1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
prompt_str = 'context' if is_legacy else 'prompt'
|
||||
if not prompt_str in body:
|
||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||
|
||||
prompt = body[prompt_str]
|
||||
if isinstance(prompt, list):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encode(encoder.decode(prompt))[0]
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = False
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
if token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||
#print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name, # TODO: add Lora info?
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": stop_reason,
|
||||
"text": answer,
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
}
|
||||
|
||||
if logprob_proc:
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
else:
|
||||
resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# generator
|
||||
def stream_completions(body: dict, is_legacy: bool=False):
|
||||
# Legacy
|
||||
# Text Completions
|
||||
#object_type = 'text_completion'
|
||||
stream_object_type = 'text_completion.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "conv-%d" % (int(time.time()*1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
prompt_str = 'context' if is_legacy else 'prompt'
|
||||
if not prompt_str in body:
|
||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||
|
||||
prompt = body[prompt_str]
|
||||
if isinstance(prompt, list):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encode(encoder.decode(prompt))[0]
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = True
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
if token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||
#print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": stream_object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"text": content,
|
||||
}],
|
||||
}
|
||||
if logprob_proc:
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
else:
|
||||
chunk[resp_list][0]["logprobs"] = None
|
||||
|
||||
return chunk
|
||||
|
||||
yield text_streaming_chunk('')
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
seen_content = answer
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
yield chunk
|
||||
|
||||
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
chunk = text_streaming_chunk('')
|
||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||
chunk["usage"] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
yield chunk
|
64
extensions/openai/defaults.py
Normal file
64
extensions/openai/defaults.py
Normal file
@ -0,0 +1,64 @@
|
||||
import copy
|
||||
|
||||
# Slightly different defaults for OpenAI's API
|
||||
# Data type is important, Ex. use 0.0 for a float 0
|
||||
default_req_params = {
|
||||
'max_new_tokens': 16, # 'Inf' for chat
|
||||
'temperature': 1.0,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1, # choose 20 for chat in absence of another default
|
||||
'repetition_penalty': 1.18,
|
||||
'repetition_penalty_range': 0,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'suffix': None,
|
||||
'stream': False,
|
||||
'echo': False,
|
||||
'seed': -1,
|
||||
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
||||
'truncation_length': 2048, # first use shared.settings value
|
||||
'add_bos_token': True,
|
||||
'do_sample': True,
|
||||
'typical_p': 1.0,
|
||||
'epsilon_cutoff': 0.0, # In units of 1e-4
|
||||
'eta_cutoff': 0.0, # In units of 1e-4
|
||||
'tfs': 1.0,
|
||||
'top_a': 0.0,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0.0,
|
||||
'length_penalty': 1.0,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5.0,
|
||||
'mirostat_eta': 0.1,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'custom_stopping_strings': '',
|
||||
# 'logits_processor' - conditionally passed
|
||||
# 'stopping_strings' - temporarily used
|
||||
# 'logprobs' - temporarily used
|
||||
# 'requested_model' - temporarily used
|
||||
}
|
||||
|
||||
def get_default_req_params():
|
||||
return copy.deepcopy(default_req_params)
|
||||
|
||||
# little helper to get defaults if arg is present but None and should be the same type as default.
|
||||
def default(dic, key, default):
|
||||
val = dic.get(key, default)
|
||||
if type(val) != type(default):
|
||||
# maybe it's just something like 1 instead of 1.0
|
||||
try:
|
||||
v = type(default)(val)
|
||||
if type(val)(v) == val: # if it's the same value passed in, it's ok.
|
||||
return v
|
||||
except:
|
||||
pass
|
||||
|
||||
val = default
|
||||
return val
|
||||
|
||||
def clamp(value, minvalue, maxvalue):
|
||||
return max(minvalue, min(value, maxvalue))
|
||||
|
102
extensions/openai/edits.py
Normal file
102
extensions/openai/edits.py
Normal file
@ -0,0 +1,102 @@
|
||||
import time
|
||||
import yaml
|
||||
import os
|
||||
from modules import shared
|
||||
from extensions.openai.defaults import get_default_req_params
|
||||
from extensions.openai.utils import debug_msg
|
||||
from extensions.openai.errors import *
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
|
||||
def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
||||
|
||||
created_time = int(time.time()*1000)
|
||||
|
||||
# Request parameters
|
||||
req_params = get_default_req_params()
|
||||
stopping_strings = []
|
||||
|
||||
# Alpaca is verbose so a good default prompt
|
||||
default_template = (
|
||||
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
)
|
||||
|
||||
instruction_template = default_template
|
||||
|
||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||
if shared.settings['instruction_template']:
|
||||
if 'Alpaca' in shared.settings['instruction_template']:
|
||||
stopping_strings.extend(['\n###'])
|
||||
else:
|
||||
try:
|
||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||
|
||||
template = instruct['turn_template']
|
||||
template = template\
|
||||
.replace('<|user|>', instruct.get('user', ''))\
|
||||
.replace('<|bot|>', instruct.get('bot', ''))\
|
||||
.replace('<|user-message|>', '{instruction}\n{input}')
|
||||
|
||||
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
||||
if instruct['user']:
|
||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])
|
||||
|
||||
except Exception as e:
|
||||
instruction_template = default_template
|
||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||
else:
|
||||
stopping_strings.extend(['\n###'])
|
||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||
|
||||
edit_task = instruction_template.format(instruction=instruction, input=input)
|
||||
|
||||
truncation_length = shared.settings['truncation_length']
|
||||
|
||||
token_count = len(encode(edit_task)[0])
|
||||
max_tokens = truncation_length - token_count
|
||||
|
||||
if max_tokens < 1:
|
||||
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
|
||||
raise InvalidRequestError(err_msg, param='input')
|
||||
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
req_params['truncation_length'] = truncation_length
|
||||
req_params['temperature'] = temperature
|
||||
req_params['top_p'] = top_p
|
||||
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
||||
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||
|
||||
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||
|
||||
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
||||
answer = ''
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
|
||||
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
|
||||
resp = {
|
||||
"object": "edit",
|
||||
"created": created_time,
|
||||
"choices": [{
|
||||
"text": answer,
|
||||
"index": 0,
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
50
extensions/openai/embeddings.py
Normal file
50
extensions/openai/embeddings.py
Normal file
@ -0,0 +1,50 @@
|
||||
import os
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from extensions.openai.utils import float_list_to_base64, debug_msg
|
||||
from extensions.openai.errors import *
|
||||
|
||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||
embeddings_model = None
|
||||
|
||||
def load_embedding_model(model):
|
||||
try:
|
||||
emb_model = SentenceTransformer(model)
|
||||
print(f"\nLoaded embedding model: {model}, max sequence length: {emb_model.max_seq_length}")
|
||||
except Exception as e:
|
||||
print(f"\nError: Failed to load embedding model: {model}")
|
||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message = repr(e))
|
||||
|
||||
return emb_model
|
||||
|
||||
def get_embeddings_model():
|
||||
global embeddings_model, st_model
|
||||
if st_model and not embeddings_model:
|
||||
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||
return embeddings_model
|
||||
|
||||
def get_embeddings_model_name():
|
||||
global st_model
|
||||
return st_model
|
||||
|
||||
def embeddings(input: list, encoding_format: str):
|
||||
|
||||
embeddings = get_embeddings_model().encode(input).tolist()
|
||||
|
||||
if encoding_format == "base64":
|
||||
data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
|
||||
else:
|
||||
data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
|
||||
|
||||
response = {
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": st_model, # return the real model
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
}
|
||||
|
||||
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
||||
|
||||
return response
|
27
extensions/openai/errors.py
Normal file
27
extensions/openai/errors.py
Normal file
@ -0,0 +1,27 @@
|
||||
class OpenAIError(Exception):
|
||||
def __init__(self, message = None, code = 500, internal_message = ''):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.internal_message = internal_message
|
||||
def __repr__(self):
|
||||
return "%s(message=%r, code=%d)" % (
|
||||
self.__class__.__name__,
|
||||
self.message,
|
||||
self.code,
|
||||
)
|
||||
|
||||
class InvalidRequestError(OpenAIError):
|
||||
def __init__(self, message, param, code = 400, error_type ='InvalidRequestError', internal_message = ''):
|
||||
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
||||
self.param = param
|
||||
def __repr__(self):
|
||||
return "%s(message=%r, code=%d, param=%s)" % (
|
||||
self.__class__.__name__,
|
||||
self.message,
|
||||
self.code,
|
||||
self.param,
|
||||
)
|
||||
|
||||
class ServiceUnavailableError(OpenAIError):
|
||||
def __init__(self, message = None, code = 500, error_type ='ServiceUnavailableError', internal_message = ''):
|
||||
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
48
extensions/openai/images.py
Normal file
48
extensions/openai/images.py
Normal file
@ -0,0 +1,48 @@
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
from extensions.openai.errors import *
|
||||
|
||||
def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
# Stable Diffusion callout wrapper for txt2img
|
||||
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
||||
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
|
||||
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
|
||||
# it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
|
||||
# Will probably work best with the stock SD models.
|
||||
# SD configuration is beyond the scope of this API.
|
||||
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
||||
# require changing the form data handling to accept multipart form data, also to properly support
|
||||
# url return types will require file management and a web serving files... Perhaps later!
|
||||
|
||||
width, height = [ int(x) for x in size.split('x') ] # ignore the restrictions on size
|
||||
|
||||
# to hack on better generation, edit default payload.
|
||||
payload = {
|
||||
'prompt': prompt, # ignore prompt limit of 1000 characters
|
||||
'width': width,
|
||||
'height': height,
|
||||
'batch_size': n,
|
||||
'restore_faces': True, # slightly less horrible
|
||||
}
|
||||
|
||||
resp = {
|
||||
'created': int(time.time()),
|
||||
'data': []
|
||||
}
|
||||
|
||||
# TODO: support SD_WEBUI_AUTH username:password pair.
|
||||
sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"
|
||||
|
||||
response = requests.post(url=sd_url, json=payload)
|
||||
r = response.json()
|
||||
if response.status_code != 200 or 'images' not in r:
|
||||
raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code = response.status_code)
|
||||
# r['parameters']...
|
||||
for b64_json in r['images']:
|
||||
if response_format == 'b64_json':
|
||||
resp['data'].extend([{'b64_json': b64_json}])
|
||||
else:
|
||||
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
|
||||
|
||||
return resp
|
77
extensions/openai/models.py
Normal file
77
extensions/openai/models.py
Normal file
@ -0,0 +1,77 @@
|
||||
from modules import shared
|
||||
from modules.utils import get_available_models
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import (get_model_settings_from_yamls,
|
||||
update_model_parameters)
|
||||
|
||||
from extensions.openai.embeddings import get_embeddings_model_name
|
||||
from extensions.openai.errors import *
|
||||
|
||||
def get_current_model_list() -> list:
|
||||
return [ shared.model_name ] # The real chat/completions model, maybe "None"
|
||||
|
||||
def get_pseudo_model_list() -> list:
|
||||
return [ # these are expected by so much, so include some here as a dummy
|
||||
'gpt-3.5-turbo',
|
||||
'text-embedding-ada-002',
|
||||
]
|
||||
|
||||
def load_model(model_name: str) -> dict:
|
||||
resp = {
|
||||
"id": model_name,
|
||||
"object": "engine",
|
||||
"owner": "self",
|
||||
"ready": True,
|
||||
}
|
||||
if model_name not in get_pseudo_model_list() + [ get_embeddings_model_name() ] + get_current_model_list(): # Real model only
|
||||
# No args. Maybe it works anyways!
|
||||
# TODO: hack some heuristics into args for better results
|
||||
|
||||
shared.model_name = model_name
|
||||
unload_model()
|
||||
|
||||
model_settings = get_model_settings_from_yamls(shared.model_name)
|
||||
shared.settings.update(model_settings)
|
||||
update_model_parameters(model_settings, initial=True)
|
||||
|
||||
if shared.settings['mode'] != 'instruct':
|
||||
shared.settings['instruction_template'] = None
|
||||
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
if not shared.model: # load failed.
|
||||
shared.model_name = "None"
|
||||
raise OpenAIError(f"Model load failed for: {shared.model_name}")
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
def list_models(is_legacy: bool = False) -> dict:
|
||||
# TODO: Lora's?
|
||||
all_model_list = get_current_model_list() + [ get_embeddings_model_name() ] + get_pseudo_model_list() + get_available_models()
|
||||
|
||||
models = {}
|
||||
|
||||
if is_legacy:
|
||||
models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ]
|
||||
if not shared.model:
|
||||
models[0]['ready'] = False
|
||||
else:
|
||||
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
|
||||
|
||||
resp = {
|
||||
"object": "list",
|
||||
"data": models,
|
||||
}
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
def model_info(model_name: str) -> dict:
|
||||
return {
|
||||
"id": model_name,
|
||||
"object": "model",
|
||||
"owned_by": "user",
|
||||
"permission": []
|
||||
}
|
||||
|
70
extensions/openai/moderations.py
Normal file
70
extensions/openai/moderations.py
Normal file
@ -0,0 +1,70 @@
|
||||
import time
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
from extensions.openai.embeddings import get_embeddings_model
|
||||
|
||||
|
||||
moderations_disabled = False # return 0/false
|
||||
category_embeddings = None
|
||||
antonym_embeddings = None
|
||||
categories = [ "sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence" ]
|
||||
flag_threshold = 0.5
|
||||
|
||||
|
||||
def get_category_embeddings():
|
||||
global category_embeddings, categories
|
||||
if category_embeddings is None:
|
||||
embeddings = get_embeddings_model().encode(categories).tolist()
|
||||
category_embeddings = dict(zip(categories, embeddings))
|
||||
|
||||
return category_embeddings
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) / (norm(a) * norm(b))
|
||||
|
||||
|
||||
# seems most openai like with all-mpnet-base-v2
|
||||
def mod_score(a, b):
|
||||
return 2.0 * np.dot(a, b)
|
||||
|
||||
|
||||
def moderations(input):
|
||||
global category_embeddings, categories, flag_threshold, moderations_disabled
|
||||
results = {
|
||||
"id": f"modr-{int(time.time()*1e9)}",
|
||||
"model": "text-moderation-001",
|
||||
"results": [],
|
||||
}
|
||||
|
||||
embeddings_model = get_embeddings_model()
|
||||
if not embeddings_model or moderations_disabled:
|
||||
results['results'] = [{
|
||||
'categories': dict([ (C, False) for C in categories]),
|
||||
'category_scores': dict([ (C, 0.0) for C in categories]),
|
||||
'flagged': False,
|
||||
}]
|
||||
return results
|
||||
|
||||
category_embeddings = get_category_embeddings()
|
||||
|
||||
|
||||
# input, string or array
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
for in_str in input:
|
||||
for ine in embeddings_model.encode([in_str]).tolist():
|
||||
category_scores = dict([ (C, mod_score(category_embeddings[C], ine)) for C in categories ])
|
||||
category_flags = dict([ (C, bool(category_scores[C] > flag_threshold)) for C in categories ])
|
||||
flagged = any(category_flags.values())
|
||||
|
||||
results['results'].extend([{
|
||||
'flagged': flagged,
|
||||
'categories': category_flags,
|
||||
'category_scores': category_scores,
|
||||
}])
|
||||
|
||||
print(results)
|
||||
|
||||
return results
|
@ -1,2 +1,3 @@
|
||||
flask_cloudflared==0.0.12
|
||||
sentence-transformers
|
||||
sentence-transformers
|
||||
tiktoken
|
File diff suppressed because it is too large
Load Diff
37
extensions/openai/tokens.py
Normal file
37
extensions/openai/tokens.py
Normal file
@ -0,0 +1,37 @@
|
||||
from extensions.openai.utils import float_list_to_base64
|
||||
from modules.text_generation import encode, decode
|
||||
|
||||
def token_count(prompt):
|
||||
tokens = encode(prompt)[0]
|
||||
|
||||
return {
|
||||
'results': [{
|
||||
'tokens': len(tokens)
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
def token_encode(input, encoding_format = ''):
|
||||
#if isinstance(input, list):
|
||||
tokens = encode(input)[0]
|
||||
|
||||
return {
|
||||
'results': [{
|
||||
'encoding_format': encoding_format,
|
||||
'tokens': float_list_to_base64(tokens) if encoding_format == "base64" else tokens,
|
||||
'length': len(tokens),
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
def token_decode(tokens, encoding_format):
|
||||
#if isinstance(input, list):
|
||||
# if encoding_format == "base64":
|
||||
# tokens = base64_to_float_list(tokens)
|
||||
output = decode(tokens)[0]
|
||||
|
||||
return {
|
||||
'results': [{
|
||||
'text': output
|
||||
}]
|
||||
}
|
26
extensions/openai/utils.py
Normal file
26
extensions/openai/utils.py
Normal file
@ -0,0 +1,26 @@
|
||||
import os
|
||||
import base64
|
||||
import numpy as np
|
||||
|
||||
def float_list_to_base64(float_list):
|
||||
# Convert the list to a float32 array that the OpenAPI client expects
|
||||
float_array = np.array(float_list, dtype="float32")
|
||||
|
||||
# Get raw bytes
|
||||
bytes_array = float_array.tobytes()
|
||||
|
||||
# Encode bytes into base64
|
||||
encoded_bytes = base64.b64encode(bytes_array)
|
||||
|
||||
# Turn raw base64 encoded bytes into ASCII
|
||||
ascii_string = encoded_bytes.decode('ascii')
|
||||
return ascii_string
|
||||
|
||||
def end_line(s):
|
||||
if s and s[-1] != '\n':
|
||||
s = s + '\n'
|
||||
return s
|
||||
|
||||
def debug_msg(*args, **kwargs):
|
||||
if 'OPENEDAI_DEBUG' in os.environ:
|
||||
print(*args, **kwargs)
|
Loading…
Reference in New Issue
Block a user