From 320fcfde4efd060c08bdc8ed52ef48b27026fe0c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 2 May 2023 23:05:38 -0300 Subject: [PATCH] Style/pep8 improvements --- extensions/openai/script.py | 130 +++++++++++++++++++----------------- 1 file changed, 69 insertions(+), 61 deletions(-) diff --git a/extensions/openai/script.py b/extensions/openai/script.py index e0ce4b37..d1e469f3 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -1,8 +1,9 @@ -import json, time, os +import json +import os +import time from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from threading import Thread - from modules import shared from modules.text_generation import encode, generate_reply @@ -25,13 +26,15 @@ embedding_model = None standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ] # little helper to get defaults if arg is present but None and should be the same type as default. + + def default(dic, key, default): val = dic.get(key, default) if type(val) != type(default): # maybe it's just something like 1 instead of 1.0 try: v = type(default)(val) - if type(val)(v) == val: # if it's the same value passed in, it's ok. + if type(val)(v) == val: # if it's the same value passed in, it's ok. return v except: pass @@ -39,6 +42,7 @@ def default(dic, key, default): val = default return val + def clamp(value, minvalue, maxvalue): return max(minvalue, min(value, maxvalue)) @@ -54,27 +58,27 @@ class Handler(BaseHTTPRequestHandler): # TODO: list all models and allow model changes via API? Lora's? # This API should list capabilities, limits and pricing... models = [{ - "id": shared.model_name, # The real chat/completions model + "id": shared.model_name, # The real chat/completions model "object": "model", "owned_by": "user", "permission": [] - }, { - "id": st_model, # The real sentence transformer embeddings model + }, { + "id": st_model, # The real sentence transformer embeddings model "object": "model", "owned_by": "user", "permission": [] - }, { # these are expected by so much, so include some here as a dummy - "id": "gpt-3.5-turbo", # /v1/chat/completions + }, { # these are expected by so much, so include some here as a dummy + "id": "gpt-3.5-turbo", # /v1/chat/completions "object": "model", "owned_by": "user", "permission": [] - }, { - "id": "text-curie-001", # /v1/completions, 2k context + }, { + "id": "text-curie-001", # /v1/completions, 2k context "object": "model", "owned_by": "user", "permission": [] - }, { - "id": "text-davinci-002", # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768 + }, { + "id": "text-davinci-002", # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768 "object": "model", "owned_by": "user", "permission": [] @@ -103,8 +107,10 @@ class Handler(BaseHTTPRequestHandler): content_length = int(self.headers['Content-Length']) body = json.loads(self.rfile.read(content_length).decode('utf-8')) - if debug: print(self.headers) # did you know... python-openai sends your linux kernel & python version? - if debug: print(body) + if debug: + print(self.headers) # did you know... python-openai sends your linux kernel & python version? + if debug: + print(body) if '/completions' in self.path or '/generate' in self.path: is_legacy = '/generate' in self.path @@ -112,7 +118,7 @@ class Handler(BaseHTTPRequestHandler): resp_list = 'data' if is_legacy else 'choices' # XXX model is ignored for now - #model = body.get('model', shared.model_name) # ignored, use existing for now + # model = body.get('model', shared.model_name) # ignored, use existing for now model = shared.model_name created_time = int(time.time()) cmpl_id = "conv-%d" % (created_time) @@ -129,11 +135,11 @@ class Handler(BaseHTTPRequestHandler): truncation_length = default(shared.settings, 'truncation_length', 2048) truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length) - default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it., the default for chat is "inf" + default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it., the default for chat is "inf" max_tokens_str = 'length' if is_legacy else 'max_tokens' max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens)) - + # hard scale this, assuming the given max is for GPT3/4, perhaps inspect the requested model and lookup the context max while truncation_length <= max_tokens: max_tokens = max_tokens // 2 @@ -143,17 +149,17 @@ class Handler(BaseHTTPRequestHandler): 'temperature': default(body, 'temperature', 1.0), 'top_p': default(body, 'top_p', 1.0), 'top_k': default(body, 'best_of', 1), - ### XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2 + # XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2 # 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it. - 'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better. - ### XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it. - 'encoder_repetition_penalty': 1.0, #(default(body, 'frequency_penalty', 0) + 2.0) / 2.0, + 'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better. + # XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it. + 'encoder_repetition_penalty': 1.0, # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0, 'suffix': body.get('suffix', None), 'stream': default(body, 'stream', False), 'echo': default(body, 'echo', False), ##################################################### 'seed': shared.settings.get('seed', -1), - #int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map + # int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map # unofficial, but it needs to get set anyways. 'truncation_length': truncation_length, # no more args. @@ -178,7 +184,7 @@ class Handler(BaseHTTPRequestHandler): if req_params['stream']: self.send_header('Content-Type', 'text/event-stream') self.send_header('Cache-Control', 'no-cache') - #self.send_header('Connection', 'keep-alive') + # self.send_header('Connection', 'keep-alive') else: self.send_header('Content-Type', 'application/json') self.end_headers() @@ -195,8 +201,8 @@ class Handler(BaseHTTPRequestHandler): messages = body['messages'] - system_msg = '' # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date} - if 'prompt' in body: # Maybe they sent both? This is not documented in the API, but some clients seem to do this. + system_msg = '' # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date} + if 'prompt' in body: # Maybe they sent both? This is not documented in the API, but some clients seem to do this. system_msg = body['prompt'] chat_msgs = [] @@ -204,16 +210,16 @@ class Handler(BaseHTTPRequestHandler): for m in messages: role = m['role'] content = m['content'] - #name = m.get('name', 'user') + # name = m.get('name', 'user') if role == 'system': system_msg += content else: - chat_msgs.extend([f"\n{role}: {content.strip()}"]) ### Strip content? linefeed? - + chat_msgs.extend([f"\n{role}: {content.strip()}"]) # Strip content? linefeed? + system_token_count = len(encode(system_msg)[0]) remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count chat_msg = '' - + while chat_msgs: new_msg = chat_msgs.pop() new_size = len(encode(new_msg)[0]) @@ -229,7 +235,7 @@ class Handler(BaseHTTPRequestHandler): print(f"truncating chat messages, dropping {len(chat_msgs)} messages.") if system_msg: - prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant: ' + prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant: ' else: prompt = chat_msg + '\nassistant: ' @@ -245,16 +251,16 @@ class Handler(BaseHTTPRequestHandler): # ... encoded as a string, array of strings, array of tokens, or array of token arrays. if is_legacy: - prompt = body['context'] # Older engines.generate API + prompt = body['context'] # Older engines.generate API else: - prompt = body['prompt'] # XXX this can be different types - + prompt = body['prompt'] # XXX this can be different types + if isinstance(prompt, list): - prompt = ''.join(prompt) # XXX this is wrong... need to split out to multiple calls? + prompt = ''.join(prompt) # XXX this is wrong... need to split out to multiple calls? token_count = len(encode(prompt)[0]) if token_count >= req_params['truncation_length']: - new_len = int(len(prompt) * (float(shared.settings['truncation_length']) - req_params['max_new_tokens']) / token_count) + new_len = int(len(prompt) * (float(shared.settings['truncation_length']) - req_params['max_new_tokens']) / token_count) prompt = prompt[-new_len:] print(f"truncating prompt to {new_len} characters, was {token_count} tokens. Now: {len(encode(prompt)[0])} tokens.") @@ -262,7 +268,6 @@ class Handler(BaseHTTPRequestHandler): # some strange cases of "##| Instruction: " sneaking through. stopping_strings += standard_stopping_strings req_params['custom_stopping_strings'] = stopping_strings - shared.args.no_stream = not req_params['stream'] if not shared.args.no_stream: @@ -283,22 +288,23 @@ class Handler(BaseHTTPRequestHandler): chunk[resp_list][0]["text"] = "" else: # This is coming back as "system" to the openapi cli, not sure why. - # So yeah... do both methods? delta and messages. + # So yeah... do both methods? delta and messages. chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''} chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''} - #{ "role": "assistant" } + # { "role": "assistant" } response = 'data: ' + json.dumps(chunk) + '\n' self.wfile.write(response.encode('utf-8')) # generate reply ####################################### - if debug: print ({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings}) + if debug: + print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings}) generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings) answer = '' seen_content = '' - longest_stop_len = max([ len(x) for x in stopping_strings ]) - + longest_stop_len = max([len(x) for x in stopping_strings]) + for a in generator: if isinstance(a, str): answer = a @@ -312,7 +318,7 @@ class Handler(BaseHTTPRequestHandler): for string in stopping_strings: idx = answer.find(string, search_start) if idx != -1: - answer = answer[:idx] # clip it. + answer = answer[:idx] # clip it. stop_string_found = True if stop_string_found: @@ -338,9 +344,9 @@ class Handler(BaseHTTPRequestHandler): # Streaming new_content = answer[len_seen:] - if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. + if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. continue - + seen_content = answer chunk = { "id": cmpl_id, @@ -355,9 +361,9 @@ class Handler(BaseHTTPRequestHandler): if stream_object_type == 'text_completion.chunk': chunk[resp_list][0]['text'] = new_content else: - # So yeah... do both methods? delta and messages. - chunk[resp_list][0]['message'] = { 'content': new_content } - chunk[resp_list][0]['delta'] = { 'content': new_content } + # So yeah... do both methods? delta and messages. + chunk[resp_list][0]['message'] = {'content': new_content} + chunk[resp_list][0]['delta'] = {'content': new_content} response = 'data: ' + json.dumps(chunk) + '\n' self.wfile.write(response.encode('utf-8')) completion_token_count += len(encode(new_content)[0]) @@ -367,7 +373,7 @@ class Handler(BaseHTTPRequestHandler): "id": cmpl_id, "object": stream_object_type, "created": created_time, - "model": model, # TODO: add Lora info? + "model": model, # TODO: add Lora info? resp_list: [{ "index": 0, "finish_reason": "stop", @@ -381,16 +387,18 @@ class Handler(BaseHTTPRequestHandler): if stream_object_type == 'text_completion.chunk': chunk[resp_list][0]['text'] = '' else: - # So yeah... do both methods? delta and messages. - chunk[resp_list][0]['message'] = {'content': '' } + # So yeah... do both methods? delta and messages. + chunk[resp_list][0]['message'] = {'content': ''} chunk[resp_list][0]['delta'] = {} response = 'data: ' + json.dumps(chunk) + '\ndata: [DONE]\n' self.wfile.write(response.encode('utf-8')) - ###### Finished if streaming. - if debug: print({'response': answer}) + # Finished if streaming. + if debug: + print({'response': answer}) return - - if debug: print({'response': answer}) + + if debug: + print({'response': answer}) completion_token_count = len(encode(answer)[0]) stop_reason = "stop" @@ -401,7 +409,7 @@ class Handler(BaseHTTPRequestHandler): "id": cmpl_id, "object": object_type, "created": created_time, - "model": model, # TODO: add Lora info? + "model": model, # TODO: add Lora info? resp_list: [{ "index": 0, "finish_reason": stop_reason, @@ -414,13 +422,13 @@ class Handler(BaseHTTPRequestHandler): } if is_chat: - resp[resp_list][0]["message"] = {"role": "assistant", "content": answer } + resp[resp_list][0]["message"] = {"role": "assistant", "content": answer} else: resp[resp_list][0]["text"] = answer response = json.dumps(resp) self.wfile.write(response.encode('utf-8')) - elif '/embeddings' in self.path and embedding_model != None: + elif '/embeddings' in self.path and embedding_model is not None: self.send_response(200) self.send_header('Content-Type', 'application/json') self.end_headers() @@ -431,19 +439,20 @@ class Handler(BaseHTTPRequestHandler): embeddings = embedding_model.encode(input).tolist() - data = [ {"object": "embedding", "embedding": emb, "index": n } for n, emb in enumerate(embeddings) ] + data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)] response = json.dumps({ "object": "list", "data": data, - "model": st_model, # return the real model + "model": st_model, # return the real model "usage": { "prompt_tokens": 0, "total_tokens": 0, } }) - if debug: print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}") + if debug: + print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}") self.wfile.write(response.encode('utf-8')) elif '/moderations' in self.path: # for now do nothing, just don't error. @@ -521,4 +530,3 @@ def run_server(): def setup(): Thread(target=run_server, daemon=True).start() -