From 97140726926afd11479aed80e5b4d6b5930119be Mon Sep 17 00:00:00 2001
From: matatonic <73265741+matatonic@users.noreply.github.com>
Date: Tue, 23 May 2023 18:58:41 -0400
Subject: [PATCH] [extensions/openai] use instruction templates with
 chat_completions (#2291)

---
 extensions/openai/script.py | 121 +++++++++++++++++++++++++-----------
 1 file changed, 83 insertions(+), 38 deletions(-)

diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 49a16b07..d41592a3 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -142,6 +142,16 @@ class Handler(BaseHTTPRequestHandler):
                 "permission": []
             })
 
+            self.wfile.write(response.encode('utf-8'))
+        elif '/billing/usage' in self.path:
+            # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
+            self.send_response(200)
+            self.send_header('Content-Type', 'application/json')
+            self.end_headers()
+
+            response = json.dumps({
+                "total_usage": 0,
+            })
             self.wfile.write(response.encode('utf-8'))
         else:
             self.send_error(404)
@@ -164,7 +174,8 @@ class Handler(BaseHTTPRequestHandler):
         # model = body.get('model', shared.model_name) # ignored, use existing for now
         model = shared.model_name
         created_time = int(time.time())
-        cmpl_id = "conv-%d" % (created_time)
+
+        cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
 
         # Try to use openai defaults or map them to something with the same intent
         stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
@@ -181,10 +192,7 @@ class Handler(BaseHTTPRequestHandler):
 
         max_tokens_str = 'length' if is_legacy else 'max_tokens'
         max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
-
-        # hard scale this, assuming the given max is for GPT3/4, perhaps inspect the requested model and lookup the context max
-        while truncation_length <= max_tokens:
-            max_tokens = max_tokens // 2
+        # if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
 
         req_params = {
             'max_new_tokens': max_tokens,
@@ -243,33 +251,75 @@ class Handler(BaseHTTPRequestHandler):
         object_type = ''
 
         if is_chat:
+            # Chat Completions
             stream_object_type = 'chat.completions.chunk'
             object_type = 'chat.completions'
 
             messages = body['messages']
 
-            system_msgs = []
-            if 'prompt' in body: # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
-                system_msgs = [ body['prompt'] ]
+            role_formats = {
+                'user': 'user: {message}\n',
+                'bot': 'assistant: {message}\n',
+                'system': '{message}',
+                'context': 'You are a helpful assistant. Answer as concisely as possible.',
+                'prompt': 'assistant:',
+            }
+            # Instruct models can be much better
+            try:
+                instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
+
+                template = instruct['turn_template']
+                system_message_template = "{message}"
+                system_message_default = instruct['context']
+                bot_start = template.find('<|bot|>')  # So far, 100% of instruction templates have this token
+                user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
+                bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
+                bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
+
+                role_formats = {
+                    'user': user_message_template,
+                    'assistant': bot_message_template,
+                    'system': system_message_template,
+                    'context': system_message_default,
+                    'prompt': bot_prompt,
+                }
+
+                if debug:
+                    print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
+            except:
+                if debug:
+                    print("Loaded default role format.")
+
+            system_msgs = []
 
             chat_msgs = []
 
+            # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
+            context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
+            if context_msg:
+                system_msgs.extend([context_msg])
+
+            # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
+            if 'prompt' in body:
+                prompt_msg = role_formats['system'].format(message=body['prompt'])
+                system_msgs.extend([prompt_msg])
+
             for m in messages:
                 role = m['role']
                 content = m['content']
-                # name = m.get('name', 'user')
+                msg = role_formats[role].format(message=content)
                 if role == 'system':
-                    system_msgs.extend([content.strip()])
+                    system_msgs.extend([msg])
                 else:
-                    chat_msgs.extend([f"\n{role}: {content.strip()}"]) # Strip content? linefeed?
+                    chat_msgs.extend([msg])
 
-            # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
-            system_msg = 'You are assistant, a large language model. Answer as concisely as possible.'
-            if system_msgs:
-                system_msg = '\n'.join(system_msgs)
+            # can't really truncate the system messages
+            system_msg = '\n'.join(system_msgs)
+            if system_msg[-1] != '\n':
+                system_msg = system_msg + '\n'
 
             system_token_count = len(encode(system_msg)[0])
-            remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
+            remaining_tokens = req_params['truncation_length'] - system_token_count
 
             chat_msg = ''
             while chat_msgs:
@@ -279,25 +329,15 @@ class Handler(BaseHTTPRequestHandler):
                     chat_msg = new_msg + chat_msg
                     remaining_tokens -= new_size
                 else:
-                    # TODO: clip a message to fit?
-                    # ie. user: ...
+                    print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
                     break
 
-            if len(chat_msgs) > 0:
-                print(f"truncating chat messages, dropping {len(chat_msgs)} messages.")
-
-            if system_msg:
-                prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant:'
-            else:
-                prompt = chat_msg + '\nassistant:'
+            prompt = system_msg + chat_msg + role_formats['prompt']
 
             token_count = len(encode(prompt)[0])
 
-            # pass with some expected stop strings.
-            # some strange cases of "##| Instruction: " sneaking through.
-            stopping_strings += standard_stopping_strings
-            req_params['custom_stopping_strings'] = stopping_strings
         else:
+            # Text Completions
             stream_object_type = 'text_completion.chunk'
             object_type = 'text_completion'
 
@@ -312,14 +352,21 @@ class Handler(BaseHTTPRequestHandler):
             token_count = len(encode(prompt)[0])
 
             if token_count >= req_params['truncation_length']:
-                new_len = int(len(prompt) * (float(shared.settings['truncation_length']) - req_params['max_new_tokens']) / token_count)
+                new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count)
                 prompt = prompt[-new_len:]
-                print(f"truncating prompt to {new_len} characters, was {token_count} tokens. Now: {len(encode(prompt)[0])} tokens.")
+                new_token_count = len(encode(prompt)[0])
+                print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
+                token_count = new_token_count
 
-            # pass with some expected stop strings.
-            # some strange cases of "##| Instruction: " sneaking through.
-            stopping_strings += standard_stopping_strings
-            req_params['custom_stopping_strings'] = stopping_strings
+        if req_params['truncation_length'] - token_count < req_params['max_new_tokens']:
+            print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
+            req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
+            print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
+
+        # pass with some expected stop strings.
+        # some strange cases of "##| Instruction: " sneaking through.
+        stopping_strings += standard_stopping_strings
+        req_params['custom_stopping_strings'] = stopping_strings
 
         if req_params['stream']:
             shared.args.chat = True
@@ -338,11 +385,9 @@ class Handler(BaseHTTPRequestHandler):
                 if stream_object_type == 'text_completion.chunk':
                     chunk[resp_list][0]["text"] = ""
                 else:
-                    # This is coming back as "system" to the openapi cli, not sure why.
                     # So yeah... do both methods? delta and messages.
                     chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                     chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
-                    # { "role": "assistant" }
 
                 response = 'data: ' + json.dumps(chunk) + '\n'
                 self.wfile.write(response.encode('utf-8'))
@@ -449,7 +494,7 @@ class Handler(BaseHTTPRequestHandler):
             if debug:
                 if answer and answer[0] == ' ':
                     answer = answer[1:]
-                print({'response': answer})
+                print({'answer': answer}, chunk)
             return
 
         # strip extra leading space off new generated content
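
For reference, the turn_template slicing above can be exercised standalone. The
values below are an illustrative Vicuna-style template, not copied from any file
under characters/instruction-following/; only the slicing logic is taken from
the patch:

# Assumed template values; real ones come from the instruction YAML files.
instruct = {
    'user': 'USER:',
    'bot': 'ASSISTANT:',
    'turn_template': '<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n',
}

template = instruct['turn_template']
bot_start = template.find('<|bot|>')
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

print(repr(user_message_template))  # 'USER: {message}\n'
print(repr(bot_message_template))   # 'ASSISTANT: {message}\n'
print(repr(bot_prompt))             # 'ASSISTANT:'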
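With the patch applied, the chat route can be smoke-tested with a plain HTTP
client. The port (5001) and the non-legacy 'choices' response shape are
assumptions; match whatever the extension reports on startup:

import requests

resp = requests.post(
    'http://127.0.0.1:5001/v1/chat/completions',
    json={
        'messages': [
            {'role': 'system', 'content': 'You are a terse assistant.'},
            {'role': 'user', 'content': 'Name three prime numbers.'},
        ],
        # oversized values are now clamped to the remaining context
        # instead of being silently halved
        'max_tokens': 64,
    },
)
print(resp.json()['choices'][0]['message']['content'])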
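The new /billing/usage stub always reports zero usage; it exists so clients that
poll the dashboard endpoint don't error out (same host/port assumption as
above):

import requests

usage = requests.get('http://127.0.0.1:5001/v1/dashboard/billing/usage'
                     '?start_date=2023-05-01&end_date=2023-05-31')
print(usage.json())  # expected: {'total_usage': 0}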