mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
773 lines
32 KiB
Python
773 lines
32 KiB
Python
import base64
|
|
import json
|
|
import os
|
|
import time
|
|
import requests
|
|
import yaml
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from threading import Thread
|
|
|
|
import numpy as np
|
|
|
|
from modules import shared
|
|
from modules.text_generation import encode, generate_reply
|
|
|
|
# Extension settings; the listening port can be overridden with OPENEDAI_PORT.
params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}

# Debug output is enabled merely by the presence of OPENEDAI_DEBUG (its value is ignored).
debug = 'OPENEDAI_DEBUG' in os.environ

# Optional dependency: install sentence_transformers and download the model
# to enable v1/embeddings. When the import fails the name stays undefined.
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    pass

# Name of the embeddings model; the model object itself is created in run_server().
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
embedding_model = None

# Stop strings that are always appended to whatever the request/settings provide.
standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
|
|
|
|
# little helper to get defaults if arg is present but None and should be the same type as default.
|
|
def default(dic, key, default):
    """Return ``dic[key]`` if it has the same type as ``default``, else a safe fallback.

    When the stored value has a different type (for example ``1`` where a
    float is expected, or ``None``), it is coerced to ``type(default)``. The
    coerced value is returned only if converting it back to the original type
    reproduces the original value; otherwise ``default`` is returned.

    Args:
        dic: Mapping to read from (anything with a ``.get`` method).
        key: Key to look up.
        default: Fallback value; its type is also the required result type.

    Returns:
        The stored value, a lossless coercion of it, or ``default``.
    """
    val = dic.get(key, default)
    if type(val) is type(default):
        return val

    # Maybe it's just something like 1 instead of 1.0.
    try:
        converted = type(default)(val)
        # Round-trip check: only accept the coercion if no information was lost.
        if type(val)(converted) == val:
            return converted
    except Exception:
        # Any conversion failure means "not usable"; unlike a bare ``except``
        # this lets KeyboardInterrupt / SystemExit propagate.
        pass

    return default
|
|
|
|
|
|
def clamp(value, minvalue, maxvalue):
    """Cap ``value`` at ``maxvalue``, then raise the result to at least ``minvalue``."""
    capped = min(value, maxvalue)
    return max(minvalue, capped)
|
|
|
|
|
|
def deduce_template():
    """Build the ``{instruction}`` / ``{input}`` prompt template for edit requests.

    For the 'Alpaca' and 'Alpaca-Input' instruction templates the built-in
    Alpaca-style template is returned directly. Otherwise the YAML file for
    ``shared.settings['instruction_template']`` is read from
    ``characters/instruction-following/`` and its ``turn_template`` is turned
    into a format string that ends just before the bot's message. If anything
    goes wrong while loading or converting it, the Alpaca-style template is
    used as a fallback.

    Returns:
        A string with ``{instruction}`` and ``{input}`` placeholders.
    """
    # Alpaca is verbose so a good default prompt
    default_template = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    )

    # Use the special instruction/input/response template for anything trained like Alpaca
    if shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']:
        return default_template

    try:
        # The context manager closes the file handle, which the previous
        # bare open() call left to the garbage collector.
        with open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r') as f:
            instruct = yaml.safe_load(f)

        template = instruct['turn_template']
        template = (
            template
            .replace('<|user|>', instruct.get('user', ''))
            .replace('<|bot|>', instruct.get('bot', ''))
            .replace('<|user-message|>', '{instruction}\n{input}')
        )
        # Keep everything up to where the bot's reply would be inserted.
        return instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
    except Exception:
        # Best effort: a missing/malformed template falls back to Alpaca.
        return default_template
|
|
|
|
|
|
def float_list_to_base64(float_list):
    """Pack ``float_list`` as float32 values and return them base64-encoded as an ASCII str."""
    # float32 is the element type the OpenAPI client expects.
    raw = np.array(float_list, dtype="float32").tobytes()
    return base64.b64encode(raw).decode('ascii')
|
|
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """HTTP handler exposing an OpenAI-style API on top of the loaded model.

    GET serves the model list and a dummy billing endpoint. POST dispatches on
    substrings of the request path to completions / chat completions / legacy
    generate, edits, image generation (forwarded to the server named by
    SD_WEBUI_URL), embeddings, moderations (always "not flagged") and a
    non-standard token-count endpoint.
    """

    def _start_json_response(self):
        """Send a 200 status line and the JSON content-type header."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()

    def _write_json(self, payload):
        """Serialize ``payload`` as JSON and write it to the response body."""
        self.wfile.write(json.dumps(payload).encode('utf-8'))

    def _write_event(self, chunk):
        """Write one server-sent event carrying ``chunk`` as JSON.

        Events are terminated by a blank line, as the event-stream format
        requires; a single newline would merge consecutive ``data:`` lines
        into one event for spec-compliant clients.
        """
        self.wfile.write(('data: ' + json.dumps(chunk) + '\n\n').encode('utf-8'))

    def do_GET(self):
        """Serve ``/v1/models[/<name>]`` and ``*/billing/usage*``; 404 otherwise."""
        if self.path.startswith('/v1/models'):
            self._start_json_response()

            # TODO: list all models and allow model changes via API? Lora's?
            # This API should list capabilities, limits and pricing...
            models = [{
                "id": shared.model_name,  # The real chat/completions model
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {
                "id": st_model,  # The real sentence transformer embeddings model
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {  # these are expected by so much, so include some here as a dummy
                "id": "gpt-3.5-turbo",  # /v1/chat/completions
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {
                "id": "text-curie-001",  # /v1/completions, 2k context
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {
                "id": "text-davinci-002",  # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
                "object": "model",
                "owned_by": "user",
                "permission": []
            }]

            if self.path == '/v1/models':
                self._write_json({
                    "object": "list",
                    "data": models,
                })
            else:
                # Any other path under the prefix is answered by echoing the
                # requested name back; it is not checked against `models`.
                the_model_name = self.path[len('/v1/models/'):]
                self._write_json({
                    "id": the_model_name,
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                })
        elif '/billing/usage' in self.path:
            # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
            self._start_json_response()
            self._write_json({
                "total_usage": 0,
            })
        else:
            self.send_error(404)

    def do_POST(self):
        """Parse the JSON request body and dispatch to the endpoint handler.

        Responds 400 when the body is missing or is not valid JSON, and 404
        when no endpoint matches the path (or the matching endpoint is not
        enabled: images need SD_WEBUI_URL, embeddings need a loaded model).
        """
        if debug:
            print(self.headers)  # did you know... python-openai sends your linux kernel & python version?

        try:
            content_length = int(self.headers.get('Content-Length', 0))
            body = json.loads(self.rfile.read(content_length).decode('utf-8'))
        except (TypeError, ValueError):
            # Covers a non-numeric Content-Length, undecodable bytes and
            # malformed/empty JSON (JSONDecodeError is a ValueError).
            self.send_error(400, 'Request body must be valid JSON')
            return

        if debug:
            print(body)

        if '/completions' in self.path or '/generate' in self.path:
            self._handle_completions(body)
        elif '/edits' in self.path:
            self._handle_edits(body)
        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            self._handle_image_generations(body)
        elif '/embeddings' in self.path and embedding_model is not None:
            self._handle_embeddings(body)
        elif '/moderations' in self.path:
            self._handle_moderations()
        elif self.path == '/api/v1/token-count':
            self._handle_token_count(body)
        else:
            print(self.path, self.headers)
            self.send_error(404)

    def _load_role_formats(self):
        """Return the per-role format strings used to render chat messages.

        Starts from a generic ``user:`` / ``assistant:`` layout and replaces it
        with one derived from the configured instruction template's YAML file
        when that file can be loaded and has the expected keys.
        """
        role_formats = {
            'user': 'user: {message}\n',
            'assistant': 'assistant: {message}\n',
            'system': '{message}',
            'context': 'You are a helpful assistant. Answer as concisely as possible.',
            'prompt': 'assistant:',
        }

        # Instruct models can be much better
        try:
            with open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r') as f:
                instruct = yaml.safe_load(f)

            template = instruct['turn_template']
            system_message_template = "{message}"
            system_message_default = instruct['context']
            bot_start = template.find('<|bot|>')  # So far, 100% of instruction templates have this token
            user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
            bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
            bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

            role_formats = {
                'user': user_message_template,
                'assistant': bot_message_template,
                'system': system_message_template,
                'context': system_message_default,
                'prompt': bot_prompt,
            }

            if debug:
                print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
        except Exception:
            # Best effort: any problem with the template keeps the generic layout.
            if debug:
                print("Loaded default role format.")

        return role_formats

    def _build_chat_prompt(self, body, truncation_length):
        """Render ``body['messages']`` into a single prompt string.

        System text (template context, an optional ``body['prompt']`` and any
        'system' messages) is always kept. The remaining messages are added
        newest-first while they fit in ``truncation_length`` tokens; once one
        does not fit, it and all older messages are dropped.
        """
        messages = body['messages']
        role_formats = self._load_role_formats()

        system_msgs = []
        chat_msgs = []

        # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
        context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
        if context_msg:
            system_msgs.append(context_msg)

        # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
        if 'prompt' in body:
            system_msgs.append(role_formats['system'].format(message=body['prompt']))

        for m in messages:
            role = m['role']
            content = m['content']
            msg = role_formats[role].format(message=content)
            if role == 'system':
                system_msgs.append(msg)
            else:
                chat_msgs.append(msg)

        # can't really truncate the system messages
        system_msg = '\n'.join(system_msgs)
        # Guard against an empty system message: indexing [-1] on '' raised
        # IndexError when the template had no context and no system text was sent.
        if system_msg and system_msg[-1] != '\n':
            system_msg = system_msg + '\n'

        system_token_count = len(encode(system_msg)[0])
        remaining_tokens = truncation_length - system_token_count
        chat_msg = ''

        while chat_msgs:
            new_msg = chat_msgs.pop()
            new_size = len(encode(new_msg)[0])
            if new_size <= remaining_tokens:
                chat_msg = new_msg + chat_msg
                remaining_tokens -= new_size
            else:
                print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
                break

        return system_msg + chat_msg + role_formats['prompt']

    def _handle_completions(self, body):
        """Handle completions, chat completions and the legacy generate endpoint.

        Maps the request onto generation parameters, builds the prompt, runs
        ``generate_reply`` and writes either a single JSON response or a
        stream of server-sent events, depending on ``body['stream']``.

        NOTE: the 200 status is sent before generation starts, so errors
        raised later (e.g. a missing 'messages' key) cannot change the status.
        """
        is_legacy = '/generate' in self.path
        is_chat = 'chat' in self.path
        resp_list = 'data' if is_legacy else 'choices'

        # XXX model is ignored for now
        # model = body.get('model', shared.model_name) # ignored, use existing for now
        model = shared.model_name
        created_time = int(time.time())

        cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)

        # Try to use openai defaults or map them to something with the same intent
        stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
        if 'stop' in body:
            if isinstance(body['stop'], str):
                stopping_strings = [body['stop']]
            elif isinstance(body['stop'], list):
                stopping_strings = body['stop']

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)

        default_max_tokens = truncation_length if is_chat else 16  # completions default, chat default is 'inf' so we need to cap it.

        max_tokens_str = 'length' if is_legacy else 'max_tokens'
        max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
        # if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough

        req_params = {
            'max_new_tokens': max_tokens,
            'temperature': default(body, 'temperature', 1.0),
            'top_p': default(body, 'top_p', 1.0),
            'top_k': default(body, 'best_of', 1),
            # XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
            # 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
            'repetition_penalty': 1.18,  # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
            # XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
            'encoder_repetition_penalty': 1.0,  # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
            'suffix': body.get('suffix', None),
            'stream': default(body, 'stream', False),
            'echo': default(body, 'echo', False),
            #####################################################
            'seed': shared.settings.get('seed', -1),
            # int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
            # unofficial, but it needs to get set anyways.
            'truncation_length': truncation_length,
            # no more args.
            'add_bos_token': shared.settings.get('add_bos_token', True),
            'do_sample': True,
            'typical_p': 1.0,
            'epsilon_cutoff': 0,  # In units of 1e-4
            'eta_cutoff': 0,  # In units of 1e-4
            'min_length': 0,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'penalty_alpha': 0.0,
            'length_penalty': 1,
            'early_stopping': False,
            'mirostat_mode': 0,
            'mirostat_tau': 5,
            'mirostat_eta': 0.1,
            'ban_eos_token': False,
            'skip_special_tokens': True,
        }

        # fixup absolute 0.0's
        for par in ['temperature', 'repetition_penalty', 'encoder_repetition_penalty']:
            req_params[par] = clamp(req_params[par], 0.001, 1.999)

        self.send_response(200)
        if req_params['stream']:
            self.send_header('Content-Type', 'text/event-stream')
            self.send_header('Cache-Control', 'no-cache')
            # self.send_header('Connection', 'keep-alive')
        else:
            self.send_header('Content-Type', 'application/json')
        self.end_headers()

        completion_token_count = 0

        if is_chat:
            # Chat Completions
            stream_object_type = 'chat.completions.chunk'
            object_type = 'chat.completions'

            prompt = self._build_chat_prompt(body, req_params['truncation_length'])
            token_count = len(encode(prompt)[0])
        else:
            # Text Completions
            stream_object_type = 'text_completion.chunk'
            object_type = 'text_completion'

            # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
            if is_legacy:
                prompt = body['context']  # Older engines.generate API
            else:
                prompt = body['prompt']  # XXX this can be different types

            if isinstance(prompt, list):
                prompt = ''.join(prompt)  # XXX this is wrong... need to split out to multiple calls?

            token_count = len(encode(prompt)[0])
            if token_count >= req_params['truncation_length']:
                # Scale the character length by the effective (possibly
                # request-lowered) token limit, not the global setting, so the
                # cut matches the limit that is checked above and below.
                new_len = int(len(prompt) * req_params['truncation_length'] / token_count)
                prompt = prompt[-new_len:]
                new_token_count = len(encode(prompt)[0])
                print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
                token_count = new_token_count

        if req_params['truncation_length'] - token_count < req_params['max_new_tokens']:
            print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
            req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
            print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")

        # pass with some expected stop strings.
        # some strange cases of "##| Instruction: " sneaking through.
        # Build a new list rather than extending in place, so neither the
        # request body nor a list held in shared.settings is modified.
        stopping_strings = list(stopping_strings) + standard_stopping_strings
        req_params['custom_stopping_strings'] = stopping_strings

        if req_params['stream']:
            shared.args.chat = True
            # begin streaming
            chunk = {
                "id": cmpl_id,
                "object": stream_object_type,
                "created": created_time,
                "model": shared.model_name,
                resp_list: [{
                    "index": 0,
                    "finish_reason": None,
                }],
            }

            if stream_object_type == 'text_completion.chunk':
                chunk[resp_list][0]["text"] = ""
            else:
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}

            self._write_event(chunk)

        # generate reply #######################################
        if debug:
            print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

        answer = ''
        seen_content = ''
        # Never empty: the standard stop strings were appended above.
        longest_stop_len = max(len(x) for x in stopping_strings)

        for a in generator:
            answer = a

            stop_string_found = False
            len_seen = len(seen_content)
            # Only re-scan the tail that could contain a newly completed stop string.
            search_start = max(len_seen - longest_stop_len, 0)

            for string in stopping_strings:
                idx = answer.find(string, search_start)
                if idx != -1:
                    answer = answer[:idx]  # clip it.
                    stop_string_found = True

            if stop_string_found:
                break

            # If something like "\nYo" is generated just before "\nYou:"
            # is completed, buffer and generate more, don't send it
            buffer_and_continue = any(
                answer.endswith(string[:j])
                for string in stopping_strings
                for j in range(len(string) - 1, 0, -1)
            )

            if buffer_and_continue:
                continue

            if req_params['stream']:
                # Streaming
                new_content = answer[len_seen:]

                if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                    continue

                seen_content = answer
                chunk = {
                    "id": cmpl_id,
                    "object": stream_object_type,
                    "created": created_time,
                    "model": shared.model_name,
                    resp_list: [{
                        "index": 0,
                        "finish_reason": None,
                    }],
                }

                # strip extra leading space off new generated content
                if len_seen == 0 and new_content[0] == ' ':
                    new_content = new_content[1:]

                if stream_object_type == 'text_completion.chunk':
                    chunk[resp_list][0]['text'] = new_content
                else:
                    # So yeah... do both methods? delta and messages.
                    chunk[resp_list][0]['message'] = {'content': new_content}
                    chunk[resp_list][0]['delta'] = {'content': new_content}
                self._write_event(chunk)
                completion_token_count += len(encode(new_content)[0])

        if req_params['stream']:
            chunk = {
                "id": cmpl_id,
                "object": stream_object_type,
                "created": created_time,
                "model": model,  # TODO: add Lora info?
                resp_list: [{
                    "index": 0,
                    "finish_reason": "stop",
                }],
                "usage": {
                    "prompt_tokens": token_count,
                    "completion_tokens": completion_token_count,
                    "total_tokens": token_count + completion_token_count
                }
            }
            if stream_object_type == 'text_completion.chunk':
                chunk[resp_list][0]['text'] = ''
            else:
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]['message'] = {'content': ''}
                chunk[resp_list][0]['delta'] = {}
            self._write_event(chunk)
            self.wfile.write('data: [DONE]\n\n'.encode('utf-8'))
            # Finished if streaming.
            if debug:
                if answer and answer[0] == ' ':
                    answer = answer[1:]
                print({'answer': answer}, chunk)
            return

        # strip extra leading space off new generated content
        if answer and answer[0] == ' ':
            answer = answer[1:]

        if debug:
            print({'response': answer})

        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= req_params['truncation_length']:
            stop_reason = "length"

        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": model,  # TODO: add Lora info?
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }

        if is_chat:
            resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
        else:
            resp[resp_list][0]["text"] = answer

        self._write_json(resp)

    def _handle_edits(self, body):
        """Handle the edits endpoint: apply ``body['instruction']`` to ``body['input']``."""
        self._start_json_response()

        created_time = int(time.time())

        # Using Alpaca format, this may work with other models too.
        instruction = body['instruction']
        input_text = body.get('input', '')

        instruction_template = deduce_template()
        edit_task = instruction_template.format(instruction=instruction, input=input_text)

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        token_count = len(encode(edit_task)[0])
        # Whatever context is left after the prompt is available for the reply.
        max_tokens = truncation_length - token_count

        req_params = {
            'max_new_tokens': max_tokens,
            'temperature': clamp(default(body, 'temperature', 1.0), 0.001, 1.999),
            'top_p': clamp(default(body, 'top_p', 1.0), 0.001, 1.0),
            'top_k': 1,
            'repetition_penalty': 1.18,
            'encoder_repetition_penalty': 1.0,
            'suffix': None,
            'stream': False,
            'echo': False,
            'seed': shared.settings.get('seed', -1),
            # 'n' : default(body, 'n', 1),  # 'n' doesn't have a direct map
            'truncation_length': truncation_length,
            'add_bos_token': shared.settings.get('add_bos_token', True),
            'do_sample': True,
            'typical_p': 1.0,
            'epsilon_cutoff': 0,  # In units of 1e-4
            'eta_cutoff': 0,  # In units of 1e-4
            'min_length': 0,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'penalty_alpha': 0.0,
            'length_penalty': 1,
            'early_stopping': False,
            'mirostat_mode': 0,
            'mirostat_tau': 5,
            'mirostat_eta': 0.1,
            'ban_eos_token': False,
            'skip_special_tokens': True,
            'custom_stopping_strings': [],
        }

        if debug:
            print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

        generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings, is_chat=False)

        # Only the last value yielded by the generator is used as the reply.
        answer = ''
        for a in generator:
            answer = a

        # some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
        if edit_task[-1] != '\n' and answer and answer[0] == ' ':
            answer = answer[1:]

        completion_token_count = len(encode(answer)[0])

        resp = {
            "object": "edit",
            "created": created_time,
            "choices": [{
                "text": answer,
                "index": 0,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }

        if debug:
            print({'answer': answer, 'completion_token_count': completion_token_count})

        self._write_json(resp)

    def _handle_image_generations(self, body):
        """Forward an image generation request to the Stable Diffusion web UI.

        Stable Diffusion callout wrapper for txt2img. Low effort
        implementation for compatibility: only "prompt", "size", "n" and
        "response_format" are read from the request. 'url' responses are
        returned as ``data:`` URIs rather than fetchable links.
        """
        # At this point the edits and variations endpoints (ie. img2img) are not
        # supported because they require multipart form data handling, and proper
        # url return types would require file management and serving files.
        self._start_json_response()

        width, height = [int(x) for x in default(body, 'size', '1024x1024').split('x')]  # ignore the restrictions on size
        response_format = default(body, 'response_format', 'url')  # or b64_json

        payload = {
            'prompt': body['prompt'],  # ignore prompt limit of 1000 characters
            'width': width,
            'height': height,
            'batch_size': default(body, 'n', 1)  # ignore the batch limits of max 10
        }

        resp = {
            'created': int(time.time()),
            'data': []
        }

        # TODO: support SD_WEBUI_AUTH username:password pair.
        sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"

        response = requests.post(url=sd_url, json=payload)
        r = response.json()
        # r['parameters']...
        for b64_json in r['images']:
            if response_format == 'b64_json':
                resp['data'].append({'b64_json': b64_json})
            else:
                resp['data'].append({'url': f'data:image/png;base64,{b64_json}'})  # yeah it's lazy. requests.get() will not work with this

        self._write_json(resp)

    def _handle_embeddings(self, body):
        """Embed ``body['input']`` (or ``body['text']``) with the sentence-transformer model.

        A single string is treated as a one-element list. When
        ``encoding_format`` is "base64" each embedding is returned as a
        base64 string of float32 values, otherwise as a list of floats.
        """
        self._start_json_response()

        texts = body['input'] if 'input' in body else body['text']
        if isinstance(texts, str):
            texts = [texts]

        embeddings = embedding_model.encode(texts).tolist()

        # If base64 is specified, encode. Otherwise, do nothing.
        use_base64 = body.get("encoding_format", "") == "base64"
        data = [
            {
                "object": "embedding",
                "embedding": float_list_to_base64(emb) if use_base64 else emb,
                "index": n,
            }
            for n, emb in enumerate(embeddings)
        ]

        response = json.dumps({
            "object": "list",
            "data": data,
            "model": st_model,  # return the real model
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            }
        })

        if debug:
            print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
        self.wfile.write(response.encode('utf-8'))

    def _handle_moderations(self):
        """Answer moderation requests with a fixed "nothing flagged" result."""
        # for now do nothing, just don't error.
        self._start_json_response()

        self._write_json({
            "id": "modr-5MWoLO",
            "model": "text-moderation-001",
            "results": [{
                "categories": {
                    "hate": False,
                    "hate/threatening": False,
                    "self-harm": False,
                    "sexual": False,
                    "sexual/minors": False,
                    "violence": False,
                    "violence/graphic": False
                },
                "category_scores": {
                    "hate": 0.0,
                    "hate/threatening": 0.0,
                    "self-harm": 0.0,
                    "sexual": 0.0,
                    "sexual/minors": 0.0,
                    "violence": 0.0,
                    "violence/graphic": 0.0
                },
                "flagged": False
            }]
        })

    def _handle_token_count(self, body):
        """Return the number of tokens ``body['prompt']`` encodes to."""
        # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
        self._start_json_response()

        tokens = encode(body['prompt'])[0]
        self._write_json({
            'results': [{
                'tokens': len(tokens)
            }]
        })
|
|
|
|
|
|
def run_server():
    """Load the embedding model (best effort) and serve the API forever.

    Binds to 0.0.0.0 when ``shared.args.listen`` is set, otherwise to
    127.0.0.1, on ``params['port']``. With ``shared.args.share`` a cloudflared
    tunnel is started if ``flask_cloudflared`` is installed. This call blocks
    in ``serve_forever()``.
    """
    global embedding_model
    try:
        embedding_model = SentenceTransformer(st_model)
        print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
    except Exception:
        # Embeddings are optional. This also covers the NameError raised when
        # the sentence_transformers import at module level failed; unlike a
        # bare ``except`` it does not swallow KeyboardInterrupt / SystemExit.
        print(f"\nFailed to load embedding model: {st_model}")

    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

    server.serve_forever()
|
|
|
|
|
|
def setup():
    """Start ``run_server`` on a daemon thread and return without waiting for it."""
    server_thread = Thread(target=run_server, daemon=True)
    server_thread.start()
|