import base64
import json
import os
import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

import numpy as np
import requests
import yaml

from modules.utils import get_available_models
from modules import shared
from modules.text_generation import encode, generate_reply

# Server settings. The listening port can be overridden with $OPENEDAI_PORT.
params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}

# Setting $OPENEDAI_DEBUG to any value (even empty) enables verbose logging.
debug = 'OPENEDAI_DEBUG' in os.environ

# Slightly different defaults for OpenAI's API.
# These are copied per request and then overridden from the request body/settings.
default_req_params = dict(
    max_new_tokens=200,
    temperature=1.0,
    top_p=1.0,
    top_k=1,
    repetition_penalty=1.18,
    encoder_repetition_penalty=1.0,
    suffix=None,
    stream=False,
    echo=False,
    seed=-1,
    # 'n' : default(body, 'n', 1),  # 'n' doesn't have a direct map
    truncation_length=2048,
    add_bos_token=True,
    do_sample=True,
    typical_p=1.0,
    epsilon_cutoff=0,  # In units of 1e-4
    eta_cutoff=0,  # In units of 1e-4
    tfs=1.0,
    top_a=0.0,
    min_length=0,
    no_repeat_ngram_size=0,
    num_beams=1,
    penalty_alpha=0.0,
    length_penalty=1,
    early_stopping=False,
    mirostat_mode=0,
    mirostat_tau=5,
    mirostat_eta=0.1,
    ban_eos_token=False,
    skip_special_tokens=True,
    custom_stopping_strings=[],
)

# Optional: install sentence_transformers and download the model to enable
# /v1/embeddings. When the import fails the name is simply left undefined.
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    pass

# Embeddings model name, overridable with $OPENEDAI_EMBEDDING_MODEL.
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
# Loaded lazily by run_server(); stays None when loading fails.
embedding_model = None

# Role markers that should end a reply even if the client sent no "stop" value.
standard_stopping_strings = [
    '\nsystem:',
    '\nuser:',
    '\nhuman:',
    '\nassistant:',
    '\n###',
]

# little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default):
    """Return ``dic[key]`` coerced to the type of ``default``, else ``default``.

    Little helper to get defaults when an arg is present but None (or of the
    wrong type) and should be the same type as ``default``. A value of another
    type is accepted only when converting it to ``type(default)`` and back
    yields the original value again (e.g. ``1`` for a float default becomes
    ``1.0``). Anything else, including a missing key, yields ``default``.

    Args:
        dic: Mapping to read from (a request body or the settings dict).
        key: Key to look up.
        default: Fallback value; its type is the type the result should have.
    """
    val = dic.get(key, default)
    if type(val) != type(default):
        # maybe it's just something like 1 instead of 1.0
        try:
            v = type(default)(val)
            if type(val)(v) == val:  # if it's the same value passed in, it's ok.
                return v
        except Exception:
            # Not convertible (e.g. int(None), int('abc')): fall back below.
            pass
        val = default
    return val


def clamp(value, minvalue, maxvalue):
    """Return ``value`` limited to the closed range ``[minvalue, maxvalue]``."""
    return max(minvalue, min(value, maxvalue))


def deduce_template():
    """Return a prompt template with ``{instruction}`` and ``{input}`` placeholders.

    The template is derived from the instruction template named by
    ``shared.settings['instruction_template']``, read from
    ``characters/instruction-following/<name>.yaml``. The Alpaca template is
    returned for 'Alpaca'/'Alpaca-Input', or when the YAML cannot be read or
    lacks the expected keys.
    """
    # Alpaca is verbose so a good default prompt
    default_template = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    )

    # Use the special instruction/input/response template for anything trained like Alpaca
    if shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']:
        return default_template

    try:
        with open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r') as f:
            instruct = yaml.safe_load(f)

        template = instruct['turn_template']
        template = template\
            .replace('<|user|>', instruct.get('user', ''))\
            .replace('<|bot|>', instruct.get('bot', ''))\
            .replace('<|user-message|>', '{instruction}\n{input}')
        # Keep everything up to (not including) where the bot's reply would go.
        return instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
    except Exception:
        return default_template


def float_list_to_base64(float_list):
    """Encode ``float_list`` as base64 of its float32 bytes, as an ASCII string.

    This is the base64 ``encoding_format`` for embeddings that the OpenAI
    client expects. The bytes are in the machine's native byte order.
    """
    # Convert the list to a float32 array that the OpenAI client expects
    float_array = np.array(float_list, dtype="float32")
    # Get raw bytes
    bytes_array = float_array.tobytes()
    # Encode bytes into base64
    encoded_bytes = base64.b64encode(bytes_array)
    # Turn raw base64 encoded bytes into ASCII
    return encoded_bytes.decode('ascii')


def _request_stopping_strings(settings, body):
    """Return a NEW list of the stop strings requested for one completion.

    ``body['stop']`` (a string or a list of strings, as in the OpenAI API)
    takes precedence; otherwise ``settings['custom_stopping_strings']`` is
    used. A fresh list is always returned. Callers may extend it without
    mutating ``settings`` or ``body``. Previously, extending the result in
    place grew the shared settings list on every request.
    """
    stopping_strings = default(settings, 'custom_stopping_strings', [])
    stop = body.get('stop')
    if isinstance(stop, str):
        stopping_strings = [stop]
    elif isinstance(stop, list):
        stopping_strings = stop
    return list(stopping_strings)


def _ends_with_partial_stop(text, stopping_strings):
    """Return True if ``text`` ends with a proper, non-empty prefix of any stop string.

    Such a tail might still grow into a complete stop string on the next
    generation step, e.g. "\\nYo" just before "\\nYou:". It must therefore be
    buffered rather than streamed.
    """
    return any(
        text.endswith(string[:j])
        for string in stopping_strings
        for j in range(1, len(string))
    )


def _keep_prompt_tail(prompt, token_count, truncation_length):
    """Keep roughly the last ``truncation_length`` tokens' worth of ``prompt``.

    The character budget is scaled by ``truncation_length / token_count``.
    Slicing from the front, rather than ``prompt[-n:]``, makes a zero budget
    yield '' instead of the entire prompt.
    """
    new_len = int(len(prompt) * truncation_length / token_count)
    return prompt[max(len(prompt) - new_len, 0):]


def _chat_role_formats():
    """Return the format strings used to render each chat role into a prompt.

    Keys: 'user', 'assistant', 'system' (templates with a ``{message}``
    placeholder), 'context' (the default system text) and 'prompt' (the text
    that cues the bot's reply). They are derived from the current instruction
    template YAML when it loads, else a plain "user:/assistant:" format is used.
    """
    role_formats = {
        'user': 'user: {message}\n',
        'assistant': 'assistant: {message}\n',
        'system': '{message}',
        'context': 'You are a helpful assistant. Answer as concisely as possible.',
        'prompt': 'assistant:',
    }

    # Instruct models can be much better
    try:
        with open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r') as f:
            instruct = yaml.safe_load(f)

        template = instruct['turn_template']
        system_message_template = "{message}"
        system_message_default = instruct['context']
        bot_start = template.find('<|bot|>')  # So far, 100% of instruction templates have this token
        user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
        bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
        bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

        role_formats = {
            'user': user_message_template,
            'assistant': bot_message_template,
            'system': system_message_template,
            'context': system_message_default,
            'prompt': bot_prompt,
        }

        if debug:
            print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
    except Exception:
        if debug:
            print("Loaded default role format.")

    return role_formats


def _build_chat_prompt(body, truncation_length):
    """Flatten an OpenAI-style chat request body into a single prompt string.

    Which parts are kept:
    - System parts are always kept. These are the template context, an
      optional top-level ``prompt``, and ``system`` role messages.
    - Other messages are added newest-first while they fit in the remaining
      ``truncation_length`` token budget. Older ones that do not fit are dropped.

    The result ends with the role prompt that cues the assistant's reply.
    """
    messages = body['messages']
    role_formats = _chat_role_formats()

    system_msgs = []
    chat_msgs = []

    # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
    context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
    if context_msg:
        system_msgs.append(context_msg)

    # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
    if 'prompt' in body:
        system_msgs.append(role_formats['system'].format(message=body['prompt']))

    for m in messages:
        role = m['role']
        msg = role_formats[role].format(message=m['content'])
        if role == 'system':
            system_msgs.append(msg)
        else:
            chat_msgs.append(msg)

    # can't really truncate the system messages
    system_msg = '\n'.join(system_msgs)
    if system_msg and system_msg[-1] != '\n':
        system_msg = system_msg + '\n'

    system_token_count = len(encode(system_msg)[0])
    remaining_tokens = truncation_length - system_token_count
    chat_msg = ''

    # Walk from the newest message backwards so the most recent context survives.
    while chat_msgs:
        new_msg = chat_msgs.pop()
        new_size = len(encode(new_msg)[0])
        if new_size <= remaining_tokens:
            chat_msg = new_msg + chat_msg
            remaining_tokens -= new_size
        else:
            print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
            break

    return system_msg + chat_msg + role_formats['prompt']


class Handler(BaseHTTPRequestHandler):
    """HTTP handler that emulates a subset of the OpenAI REST API."""

    def send_access_control_headers(self):
        """Send permissive CORS headers so browser clients can call the API."""
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Credentials", "true")
        self.send_header(
            "Access-Control-Allow-Methods",
            "GET,HEAD,OPTIONS,POST,PUT"
        )
        self.send_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, X-Requested-With, Content-Type, "
            "Access-Control-Request-Method, Access-Control-Request-Headers, "
            "Authorization"
        )

    def _start_json_response(self):
        """Send a 200 status plus CORS and JSON content-type headers."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()

    def _write_json(self, obj):
        """Serialize ``obj`` to JSON and write it as the UTF-8 response body."""
        self.wfile.write(json.dumps(obj).encode('utf-8'))

    def _write_stream_chunk(self, chunk, done=False):
        """Write ``chunk`` as one server-sent ``data:`` event.

        Each event is prefixed with its length in hex, mimicking HTTP chunked
        framing. NOTE(review): no ``Transfer-Encoding: chunked`` header is
        sent, and the length counts characters rather than UTF-8 bytes.
        Presumably SSE clients just skip the bare length line; confirm before
        changing. When ``done`` is true, the ``data: [DONE]`` terminator follows.
        """
        data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
        chunk_size = hex(len(data_chunk))[2:] + '\r\n'
        response = chunk_size + data_chunk
        if done:
            response += 'data: [DONE]\r\n\r\n'
        self.wfile.write(response.encode('utf-8'))

    def do_OPTIONS(self):
        """Answer CORS preflight requests."""
        self._start_json_response()
        self.wfile.write("OK".encode('utf-8'))

    def do_GET(self):
        """Serve /v1/models[/<id>] and a dummy billing-usage endpoint; 404 otherwise."""
        if self.path.startswith('/v1/models'):
            self._start_json_response()

            # TODO: list all models and allow model changes via API? Lora's?
            # This API should list capabilities, limits and pricing...
            models = [{
                "id": shared.model_name,  # The real chat/completions model
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {
                "id": st_model,  # The real sentence transformer embeddings model
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {  # these are expected by so much, so include some here as a dummy
                "id": "gpt-3.5-turbo",  # /v1/chat/completions
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {
                "id": "text-curie-001",  # /v1/completions, 2k context
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {
                "id": "text-davinci-002",  # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
                "object": "model",
                "owned_by": "user",
                "permission": []
            }]

            models.extend([{
                "id": model_id,
                "object": "model",
                "owned_by": "user",
                "permission": []
            } for model_id in get_available_models()])

            if self.path == '/v1/models':
                response = json.dumps({
                    "object": "list",
                    "data": models,
                })
            else:
                the_model_name = self.path[len('/v1/models/'):]
                response = json.dumps({
                    "id": the_model_name,
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                })

            self.wfile.write(response.encode('utf-8'))
        elif '/billing/usage' in self.path:
            # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
            self._start_json_response()
            self._write_json({
                "total_usage": 0,
            })
        else:
            self.send_error(404)

    def do_POST(self):
        """Parse the JSON request body and dispatch to the matching endpoint handler."""
        if debug:
            print(self.headers)  # did you know... python-openai sends your linux kernel & python version?
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if debug:
            print(body)

        if '/completions' in self.path or '/generate' in self.path:
            self._handle_completions(body)
        elif '/edits' in self.path:
            self._handle_edits(body)
        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            self._handle_image_generation(body)
        elif '/embeddings' in self.path and embedding_model is not None:
            self._handle_embeddings(body)
        elif '/moderations' in self.path:
            self._handle_moderations()
        elif self.path == '/api/v1/token-count':
            self._handle_token_count(body)
        else:
            print(self.path, self.headers)
            self.send_error(404)

    def _handle_completions(self, body):
        """Serve chat completions, text completions and the legacy ``/generate`` API.

        The reply is either a single JSON object or, when ``body['stream']``
        is true, a sequence of server-sent events ending with ``data: [DONE]``.
        """
        is_legacy = '/generate' in self.path
        is_chat = 'chat' in self.path
        resp_list = 'data' if is_legacy else 'choices'

        # XXX model is ignored for now
        # model = body.get('model', shared.model_name) # ignored, use existing for now
        model = shared.model_name
        created_time = int(time.time())

        cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)

        # Try to use openai defaults or map them to something with the same intent.
        # This is a fresh list, so extending it below cannot touch shared.settings.
        stopping_strings = _request_stopping_strings(shared.settings, body)

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)

        default_max_tokens = truncation_length if is_chat else 16  # completions default, chat default is 'inf' so we need to cap it.

        max_tokens_str = 'length' if is_legacy else 'max_tokens'
        max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
        # if the user assumes OpenAI, the max_tokens is way too large - it is capped to the remaining context below.

        req_params = default_req_params.copy()
        req_params['max_new_tokens'] = max_tokens
        req_params['truncation_length'] = truncation_length
        req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0
        req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
        req_params['top_k'] = default(body, 'best_of', default_req_params['top_k'])
        req_params['suffix'] = default(body, 'suffix', default_req_params['suffix'])
        req_params['stream'] = default(body, 'stream', default_req_params['stream'])
        req_params['echo'] = default(body, 'echo', default_req_params['echo'])
        req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
        req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

        self.send_response(200)
        self.send_access_control_headers()
        if req_params['stream']:
            self.send_header('Content-Type', 'text/event-stream')
            self.send_header('Cache-Control', 'no-cache')
            # self.send_header('Connection', 'keep-alive')
        else:
            self.send_header('Content-Type', 'application/json')
        self.end_headers()

        completion_token_count = 0

        if is_chat:
            # Chat Completions
            stream_object_type = 'chat.completions.chunk'
            object_type = 'chat.completions'
            prompt = _build_chat_prompt(body, req_params['truncation_length'])
            token_count = len(encode(prompt)[0])
        else:
            # Text Completions
            stream_object_type = 'text_completion.chunk'
            object_type = 'text_completion'

            # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
            if is_legacy:
                prompt = body['context']  # Older engines.generate API
            else:
                prompt = body['prompt']  # XXX this can be different types

            if isinstance(prompt, list):
                prompt = ''.join(prompt)  # XXX this is wrong... need to split out to multiple calls?

            token_count = len(encode(prompt)[0])
            if token_count >= req_params['truncation_length']:
                # Use the per-request limit (which may be smaller than the global setting).
                prompt = _keep_prompt_tail(prompt, token_count, req_params['truncation_length'])
                new_len = len(prompt)
                new_token_count = len(encode(prompt)[0])
                print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
                token_count = new_token_count

        if req_params['truncation_length'] - token_count < req_params['max_new_tokens']:
            print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
            req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
            print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")

        # pass with some expected stop strings.
        # some strange cases of "##| Instruction: " sneaking through.
        stopping_strings = stopping_strings + standard_stopping_strings
        req_params['custom_stopping_strings'] = stopping_strings

        def new_chunk(finish_reason):
            # Skeleton of one streamed event; the caller fills in the text/delta.
            return {
                "id": cmpl_id,
                "object": stream_object_type,
                "created": created_time,
                "model": model,  # TODO: add Lora info?
                resp_list: [{
                    "index": 0,
                    "finish_reason": finish_reason,
                }],
            }

        def stream_content(new_content, is_first):
            # Send one content event and return its token count.
            # strip extra leading space off new generated content
            if is_first and new_content[0] == ' ':
                new_content = new_content[1:]
            chunk = new_chunk(None)
            if stream_object_type == 'text_completion.chunk':
                chunk[resp_list][0]['text'] = new_content
            else:
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]['message'] = {'content': new_content}
                chunk[resp_list][0]['delta'] = {'content': new_content}
            self._write_stream_chunk(chunk)
            return len(encode(new_content)[0])

        if req_params['stream']:
            shared.args.chat = True
            # begin streaming
            chunk = new_chunk(None)
            if stream_object_type == 'text_completion.chunk':
                chunk[resp_list][0]["text"] = ""
            else:
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
            self._write_stream_chunk(chunk)

        # generate reply #######################################
        if debug:
            print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

        answer = ''
        seen_content = ''  # the prefix of `answer` already streamed to the client
        longest_stop_len = max(len(x) for x in stopping_strings)

        for a in generator:
            answer = a

            stop_string_found = False
            len_seen = len(seen_content)
            # A stop string can only begin near the end of what was already seen.
            search_start = max(len_seen - longest_stop_len, 0)

            for string in stopping_strings:
                idx = answer.find(string, search_start)
                if idx != -1:
                    answer = answer[:idx]  # clip it.
                    stop_string_found = True

            if stop_string_found:
                break

            # If something like "\nYo" is generated just before "\nYou:"
            # is completed, buffer and generate more, don't send it
            if _ends_with_partial_stop(answer, stopping_strings):
                continue

            if req_params['stream']:
                new_content = answer[len_seen:]

                if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                    continue

                seen_content = answer
                completion_token_count += stream_content(new_content, is_first=len_seen == 0)

        if req_params['stream']:
            # The loop never sends text it is still buffering. That includes a
            # possible partial stop string or partial unicode character when
            # the generator ended, and text just before a clipped stop string.
            # Flush that unsent tail now so streamed and non-streamed replies match.
            tail = answer[len(seen_content):]
            if tail:
                completion_token_count += stream_content(tail, is_first=not seen_content)

            chunk = new_chunk("stop")
            chunk["usage"] = {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
            if stream_object_type == 'text_completion.chunk':
                chunk[resp_list][0]['text'] = ''
            else:
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]['message'] = {'content': ''}
                chunk[resp_list][0]['delta'] = {'content': ''}

            self._write_stream_chunk(chunk, done=True)
            # Finished if streaming.
            if debug:
                if answer and answer[0] == ' ':
                    answer = answer[1:]
                print({'answer': answer}, chunk)
            return

        # strip extra leading space off new generated content
        if answer and answer[0] == ' ':
            answer = answer[1:]

        if debug:
            print({'response': answer})

        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= req_params['truncation_length']:
            stop_reason = "length"

        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": model,  # TODO: add Lora info?
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }

        if is_chat:
            resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
        else:
            resp[resp_list][0]["text"] = answer

        self._write_json(resp)

    def _handle_edits(self, body):
        """Serve /v1/edits by rendering instruction + input through an instruction template."""
        self._start_json_response()

        created_time = int(time.time())

        # Using Alpaca format, this may work with other models too.
        instruction = body['instruction']
        input_text = body.get('input', '')

        instruction_template = deduce_template()
        edit_task = instruction_template.format(instruction=instruction, input=input_text)

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        token_count = len(encode(edit_task)[0])
        max_tokens = truncation_length - token_count

        req_params = default_req_params.copy()
        req_params['max_new_tokens'] = max_tokens
        req_params['truncation_length'] = truncation_length
        req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0
        req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
        req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
        req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

        if debug:
            print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

        generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings, is_chat=False)

        answer = ''
        for a in generator:
            answer = a

        # some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
        if not edit_task.endswith('\n') and answer and answer[0] == ' ':
            answer = answer[1:]

        completion_token_count = len(encode(answer)[0])

        resp = {
            "object": "edit",
            "created": created_time,
            "choices": [{
                "text": answer,
                "index": 0,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }

        if debug:
            print({'answer': answer, 'completion_token_count': completion_token_count})

        self._write_json(resp)

    def _handle_image_generation(self, body):
        """Forward /v1/images/generations to Stable Diffusion's txt2img at $SD_WEBUI_URL.

        Stable Diffusion callout wrapper for txt2img. This is a low-effort
        implementation for compatibility. With only "prompt" being passed and
        DALL-E being assumed, the results will be limited and likely poor; SD
        has hundreds of models and dozens of settings. If you want high-quality,
        tailored results, use the Stable Diffusion API directly. The API is too
        general to shape results with specific tags like "masterpiece", so it
        will probably work best with the stock SD models. SD configuration is
        beyond the scope of this API. The edits and variations endpoints
        (i.e. img2img) are not implemented. They would require multipart form
        data handling, and proper url return types would need file management
        and file serving.
        """
        self._start_json_response()

        width, height = [int(x) for x in default(body, 'size', '1024x1024').split('x')]  # ignore the restrictions on size
        response_format = default(body, 'response_format', 'url')  # or b64_json

        payload = {
            'prompt': body['prompt'],  # ignore prompt limit of 1000 characters
            'width': width,
            'height': height,
            'batch_size': default(body, 'n', 1)  # ignore the batch limits of max 10
        }

        resp = {
            'created': int(time.time()),
            'data': []
        }

        # TODO: support SD_WEBUI_AUTH username:password pair.
        sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"

        response = requests.post(url=sd_url, json=payload)
        r = response.json()
        # r['parameters']...
        for b64_json in r['images']:
            if response_format == 'b64_json':
                resp['data'].append({'b64_json': b64_json})
            else:
                # yeah it's lazy. requests.get() will not work with this
                resp['data'].append({'url': f'data:image/png;base64,{b64_json}'})

        self._write_json(resp)

    def _handle_embeddings(self, body):
        """Serve /v1/embeddings using the loaded sentence-transformers model."""
        self._start_json_response()

        input_texts = body['input'] if 'input' in body else body['text']
        if type(input_texts) is str:
            input_texts = [input_texts]

        embeddings = embedding_model.encode(input_texts).tolist()

        # If base64 is specified, encode. Otherwise, return the plain float lists.
        use_base64 = body.get("encoding_format", "") == "base64"
        data = [
            {
                "object": "embedding",
                "embedding": float_list_to_base64(emb) if use_base64 else emb,
                "index": n,
            }
            for n, emb in enumerate(embeddings)
        ]

        response = json.dumps({
            "object": "list",
            "data": data,
            "model": st_model,  # return the real model
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            }
        })

        if debug:
            print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")

        self.wfile.write(response.encode('utf-8'))

    def _handle_moderations(self):
        """Serve /v1/moderations with a fixed "nothing flagged" result (for now do nothing, just don't error)."""
        self._start_json_response()

        self._write_json({
            "id": "modr-5MWoLO",
            "model": "text-moderation-001",
            "results": [{
                "categories": {
                    "hate": False,
                    "hate/threatening": False,
                    "self-harm": False,
                    "sexual": False,
                    "sexual/minors": False,
                    "violence": False,
                    "violence/graphic": False
                },
                "category_scores": {
                    "hate": 0.0,
                    "hate/threatening": 0.0,
                    "self-harm": 0.0,
                    "sexual": 0.0,
                    "sexual/minors": 0.0,
                    "violence": 0.0,
                    "violence/graphic": 0.0
                },
                "flagged": False
            }]
        })

    def _handle_token_count(self, body):
        """Serve /api/v1/token-count: the tokenized length of ``body['prompt']``.

        NOT STANDARD. Lifted from the api extension, but it's still very
        useful for calculating tokenized length client side.
        """
        self._start_json_response()

        tokens = encode(body['prompt'])[0]
        self._write_json({
            'results': [{
                'tokens': len(tokens)
            }]
        })


def run_server():
    """Load the optional embedding model, then serve the API forever on the configured port."""
    global embedding_model
    try:
        embedding_model = SentenceTransformer(st_model)
        print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
    except Exception:
        # SentenceTransformer is undefined when its import failed (NameError),
        # or the model may fail to load; embeddings are optional, so keep going.
        print(f"\nFailed to load embedding model: {st_model}")

    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

    server.serve_forever()


def setup():
    """Start the API server on a daemon thread so it does not block process exit."""
    Thread(target=run_server, daemon=True).start()