From 4a17a5db678381595e2e15d9b0879a9009fe4748 Mon Sep 17 00:00:00 2001
From: matatonic <73265741+matatonic@users.noreply.github.com>
Date: Tue, 6 Jun 2023 00:43:04 -0400
Subject: [PATCH] [extensions/openai] various fixes (#2533)

---
 extensions/openai/README.md |   8 +-
 extensions/openai/script.py | 192 +++++++++++++++++++-----------------
 2 files changed, 109 insertions(+), 91 deletions(-)

diff --git a/extensions/openai/README.md b/extensions/openai/README.md
index 0a9ed20a..ffbea76e 100644
--- a/extensions/openai/README.md
+++ b/extensions/openai/README.md
@@ -20,6 +20,12 @@ Example:
 SD_WEBUI_URL=http://127.0.0.1:7861
 ```
 
+Make sure you enable the extension in your server launch parameters; they should include:
+
+```
+--extensions openai
+```
+
 ### Embeddings (alpha)
 
 Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
@@ -42,7 +48,7 @@ Almost everything you use it with will require you to set a dummy OpenAI API key
 With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so:
 
 ```
-OPENAI_API_KEY=dummy
+OPENAI_API_KEY=sk-dummy
 OPENAI_API_BASE=http://127.0.0.1:5001/v1
 ```
 
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index c5c5f2bb..97eae7b7 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -20,6 +20,7 @@ params = {
 debug = True if 'OPENEDAI_DEBUG' in os.environ else False
 
 # Slightly different defaults for OpenAI's API
+# Data type is important, Ex. use 0.0 for a float 0
 default_req_params = {
     'max_new_tokens': 200,
     'temperature': 1.0,
@@ -44,14 +45,14 @@ default_req_params = {
     'no_repeat_ngram_size': 0,
     'num_beams': 1,
     'penalty_alpha': 0.0,
-    'length_penalty': 1,
+    'length_penalty': 1.0,
     'early_stopping': False,
     'mirostat_mode': 0,
-    'mirostat_tau': 5,
+    'mirostat_tau': 5.0,
     'mirostat_eta': 0.1,
     'ban_eos_token': False,
     'skip_special_tokens': True,
-    'custom_stopping_strings': [],
+    'custom_stopping_strings': ['\n###'],
 }
 
 # Optional, install the module and download the model to enable
@@ -64,8 +65,6 @@ except ImportError:
 st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
 embedding_model = None
 
-standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
-
 # little helper to get defaults if arg is present but None and should be the same type as default.
 def default(dic, key, default):
     val = dic.get(key, default)
@@ -86,31 +85,6 @@ def clamp(value, minvalue, maxvalue):
     return max(minvalue, min(value, maxvalue))
 
-
-def deduce_template():
-    # Alpaca is verbose so a good default prompt
-    default_template = (
-        "Below is an instruction that describes a task, paired with an input that provides further context. "
" - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - ) - - # Use the special instruction/input/response template for anything trained like Alpaca - if shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']: - return default_template - - try: - instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r')) - - template = instruct['turn_template'] - template = template\ - .replace('<|user|>', instruct.get('user', ''))\ - .replace('<|bot|>', instruct.get('bot', ''))\ - .replace('<|user-message|>', '{instruction}\n{input}') - return instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ') - except: - return default_template - - def float_list_to_base64(float_list): # Convert the list to a float32 array that the OpenAPI client expects float_array = np.array(float_list, dtype="float32") @@ -139,8 +113,27 @@ class Handler(BaseHTTPRequestHandler): "Origin, Accept, X-Requested-With, Content-Type, " "Access-Control-Request-Method, Access-Control-Request-Headers, " "Authorization" - ) - + ) + + def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''): + self.send_response(code) + self.send_access_control_headers() + self.send_header('Content-Type', 'application/json') + self.end_headers() + error_resp = { + 'error': { + 'message': message, + 'code': code, + 'type': error_type, + 'param': param, + } + } + if internal_message: + error_resp['internal_message'] = internal_message + + response = json.dumps(error_resp) + self.wfile.write(response.encode('utf-8')) + def do_OPTIONS(self): self.send_response(200) self.send_access_control_headers() @@ -150,42 +143,24 @@ class Handler(BaseHTTPRequestHandler): def do_GET(self): if self.path.startswith('/v1/models'): - self.send_response(200) self.send_access_control_headers() self.send_header('Content-Type', 'application/json') self.end_headers() - # TODO: list all models and allow model changes via API? Lora's? + # TODO: Lora's? # This API should list capabilities, limits and pricing... 
-            models = [{
-                "id": shared.model_name,  # The real chat/completions model
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }, {
-                "id": st_model,  # The real sentence transformer embeddings model
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }, {  # these are expected by so much, so include some here as a dummy
-                "id": "gpt-3.5-turbo",  # /v1/chat/completions
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }, {
-                "id": "text-curie-001",  # /v1/completions, 2k context
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }, {
-                "id": "text-davinci-002",  # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }]
+            current_model_list = [ shared.model_name ]  # The real chat/completions model
+            embeddings_model_list = [ st_model ] if embedding_model else []  # The real sentence transformer embeddings model
+            pseudo_model_list = [  # many clients expect these, so include some dummies here
+                'gpt-3.5-turbo',  # /v1/chat/completions
+                'text-curie-001',  # /v1/completions, 2k context
+                'text-davinci-002'  # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
+            ]
+            available_model_list = get_available_models()
+            all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
 
-            models.extend([{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in get_available_models() ])
+            models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
 
             response = ''
             if self.path == '/v1/models':
@@ -203,6 +178,7 @@ class Handler(BaseHTTPRequestHandler):
             })
 
             self.wfile.write(response.encode('utf-8'))
+
         elif '/billing/usage' in self.path:
             # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
             self.send_response(200)
@@ -214,6 +190,7 @@
                 "total_usage": 0,
             })
 
             self.wfile.write(response.encode('utf-8'))
+
         else:
             self.send_error(404)
@@ -227,6 +204,11 @@ class Handler(BaseHTTPRequestHandler):
             print(body)
 
         if '/completions' in self.path or '/generate' in self.path:
+
+            if not shared.model:
+                self.openai_error("No model loaded.")
+                return
+
             is_legacy = '/generate' in self.path
             is_chat = 'chat' in self.path
             resp_list = 'data' if is_legacy else 'choices'
@@ -238,13 +220,16 @@
 
             cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
 
+            # Request Parameters
             # Try to use openai defaults or map them to something with the same intent
-            stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
+            req_params = default_req_params.copy()
+            req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
+
             if 'stop' in body:
                 if isinstance(body['stop'], str):
-                    stopping_strings = [body['stop']]
+                    req_params['custom_stopping_strings'].extend([body['stop']])
                 elif isinstance(body['stop'], list):
-                    stopping_strings = body['stop']
+                    req_params['custom_stopping_strings'].extend(body['stop'])
 
             truncation_length = default(shared.settings, 'truncation_length', 2048)
             truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
@@ -255,8 +240,6 @@
             max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
             # if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
 
-            req_params = default_req_params.copy()
-
             req_params['max_new_tokens'] = max_tokens
             req_params['truncation_length'] = truncation_length
             req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
@@ -319,9 +302,14 @@
                         'prompt': bot_prompt,
                     }
 
+                    if instruct['user']:  # WizardLM and some others have no user prompt.
+                        req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
+
                     if debug:
                         print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
                 except:
+                    req_params['custom_stopping_strings'].extend(['\nuser:'])
+
                     if debug:
                         print("Loaded default role format.")
@@ -396,11 +384,6 @@
                 print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
                 req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
                 print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
-
-            # pass with some expected stop strings.
-            # some strange cases of "##| Instruction: " sneaking through.
-            stopping_strings += standard_stopping_strings
-            req_params['custom_stopping_strings'] = stopping_strings
 
             if req_params['stream']:
                 shared.args.chat = True
@@ -423,19 +406,17 @@
                     chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                     chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
 
-                data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
-                chunk_size = hex(len(data_chunk))[2:] + '\r\n'
-                response = chunk_size + data_chunk
+                response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
                 self.wfile.write(response.encode('utf-8'))
 
             # generate reply #######################################
             if debug:
-                print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
-            generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
+                print({'prompt': prompt, 'req_params': req_params})
+            generator = generate_reply(prompt, req_params, is_chat=False)
 
             answer = ''
             seen_content = ''
-            longest_stop_len = max([len(x) for x in stopping_strings])
+            longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
 
             for a in generator:
                 answer = a
@@ -444,7 +425,7 @@
                 len_seen = len(seen_content)
                 search_start = max(len_seen - longest_stop_len, 0)
 
-                for string in stopping_strings:
+                for string in req_params['custom_stopping_strings']:
                     idx = answer.find(string, search_start)
                     if idx != -1:
                         answer = answer[:idx]  # clip it.
@@ -457,7 +438,7 @@
                 # is completed, buffer and generate more, don't send it
                 buffer_and_continue = False
 
-                for string in stopping_strings:
+                for string in req_params['custom_stopping_strings']:
                     for j in range(len(string) - 1, 0, -1):
                         if answer[-j:] == string[:j]:
                             buffer_and_continue = True
@@ -498,9 +479,7 @@
 
                     # So yeah... do both methods? delta and messages.
                     chunk[resp_list][0]['message'] = {'content': new_content}
                    chunk[resp_list][0]['delta'] = {'content': new_content}
-                    data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
-                    chunk_size = hex(len(data_chunk))[2:] + '\r\n'
-                    response = chunk_size + data_chunk
+                    response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
                     self.wfile.write(response.encode('utf-8'))
                     completion_token_count += len(encode(new_content)[0])
@@ -527,10 +506,7 @@
                 chunk[resp_list][0]['message'] = {'content': ''}
                 chunk[resp_list][0]['delta'] = {'content': ''}
 
-                data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
-                chunk_size = hex(len(data_chunk))[2:] + '\r\n'
-                done = 'data: [DONE]\r\n\r\n'
-                response = chunk_size + data_chunk + done
+                response = 'data: ' + json.dumps(chunk) + '\r\n\r\ndata: [DONE]\r\n\r\n'
                 self.wfile.write(response.encode('utf-8'))
                 # Finished if streaming.
                 if debug:
@@ -574,7 +550,12 @@
                 response = json.dumps(resp)
                 self.wfile.write(response.encode('utf-8'))
+
         elif '/edits' in self.path:
+            if not shared.model:
+                self.openai_error("No model loaded.")
+                return
+
             self.send_response(200)
             self.send_access_control_headers()
             self.send_header('Content-Type', 'application/json')
             self.end_headers()
@@ -586,15 +567,42 @@
             instruction = body['instruction']
             input = body.get('input', '')
 
-            instruction_template = deduce_template()
+            # Request parameters
+            req_params = default_req_params.copy()
+
+            # Alpaca is verbose so a good default prompt
+            default_template = (
+                "Below is an instruction that describes a task, paired with an input that provides further context. "
+                "Write a response that appropriately completes the request.\n\n"
+                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            )
+
+            instruction_template = default_template
+            req_params['custom_stopping_strings'] = [ '\n###' ]
+
+            # Use the special instruction/input/response template for anything trained like Alpaca
+            if not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
+                try:
+                    instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
+
+                    template = instruct['turn_template']
+                    template = template\
+                        .replace('<|user|>', instruct.get('user', ''))\
+                        .replace('<|bot|>', instruct.get('bot', ''))\
+                        .replace('<|user-message|>', '{instruction}\n{input}')
+
+                    instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
+                    if instruct['user']:
+                        req_params['custom_stopping_strings'] = [ '\n' + instruct['user'], instruct['user'] ]
+                except:
+                    pass
+
             edit_task = instruction_template.format(instruction=instruction, input=input)
 
             truncation_length = default(shared.settings, 'truncation_length', 2048)
             token_count = len(encode(edit_task)[0])
             max_tokens = truncation_length - token_count
 
-            req_params = default_req_params.copy()
-
             req_params['max_new_tokens'] = max_tokens
             req_params['truncation_length'] = truncation_length
             req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
@@ -605,7 +613,7 @@
             if debug:
                 print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
 
-            generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings, is_chat=False)
+            generator = generate_reply(edit_task, req_params, is_chat=False)
 
             answer = ''
             for a in generator:
@@ -636,6 +644,7 @@
 
             response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))
+
         elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
             # Stable Diffusion callout wrapper for txt2img
             # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
@@ -682,6 +691,7 @@
 
             response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))
+
         elif '/embeddings' in self.path and embedding_model is not None:
             self.send_response(200)
             self.send_access_control_headers()
@@ -715,6 +725,7 @@
             if debug:
                 print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
             self.wfile.write(response.encode('utf-8'))
+
         elif '/moderations' in self.path:
             # for now do nothing, just don't error.
             self.send_response(200)
@@ -763,6 +774,7 @@
                 }]
             })
             self.wfile.write(response.encode('utf-8'))
+
         else:
             print(self.path, self.headers)
             self.send_error(404)
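
For reference, the client-side setup the README changes above describe looks like this in practice. This is a minimal sketch using the 0.x openai-python API that was current when this patch was written; the port and the `gpt-3.5-turbo` dummy id come from the patch itself, everything else is illustrative:

```
import os

# Set these before importing openai, as the README notes.
os.environ['OPENAI_API_KEY'] = 'sk-dummy'  # any non-empty key works
os.environ['OPENAI_API_BASE'] = 'http://127.0.0.1:5001/v1'

import openai

# 'gpt-3.5-turbo' is one of the dummy ids served by /v1/models above.
completion = openai.ChatCompletion.create(
    model='gpt-3.5-turbo',
    messages=[{'role': 'user', 'content': 'Hello!'}],
)
print(completion.choices[0].message.content)
```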
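
The streaming changes drop the hand-rolled chunked-transfer framing (the hex `chunk_size` lines) in favor of plain SSE `data:` events terminated by `data: [DONE]`. A sketch of consuming that stream directly with `requests` (not part of the patch; assumes the default port and a chat completion):

```
import json
import requests

resp = requests.post(
    'http://127.0.0.1:5001/v1/chat/completions',
    json={
        'model': 'gpt-3.5-turbo',
        'messages': [{'role': 'user', 'content': 'Hello!'}],
        'stream': True,
    },
    stream=True,
)
for line in resp.iter_lines():
    # After this patch each event is a bare 'data: {...}' line;
    # blank lines separate events.
    if not line or not line.startswith(b'data: '):
        continue
    payload = line[len(b'data: '):]
    if payload == b'[DONE]':
        break
    chunk = json.loads(payload)
    print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)
```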
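
Finally, the new `openai_error()` helper means a request made with no model loaded now gets an OpenAI-style JSON error body instead of an unhandled exception. A client-side check might look like this (a sketch; the field names mirror the `error_resp` dict built in the patch):

```
import requests

resp = requests.post('http://127.0.0.1:5001/v1/chat/completions',
                     json={'model': 'gpt-3.5-turbo',
                           'messages': [{'role': 'user', 'content': 'Hello!'}]})
body = resp.json()
if 'error' in body:
    # e.g. {'error': {'message': 'No model loaded.', 'code': 500,
    #                 'type': 'APIError', 'param': ''}}
    print(body['error']['code'], body['error']['message'])
```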