mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 02:49:30 +01:00
[extensions/openai] use instruction templates with chat_completions (#2291)
This commit is contained in:
parent
74aae34beb
commit
9714072692
@ -142,6 +142,16 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
"permission": []
|
"permission": []
|
||||||
})
|
})
|
||||||
|
|
||||||
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
elif '/billing/usage' in self.path:
|
||||||
|
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
response = json.dumps({
|
||||||
|
"total_usage": 0,
|
||||||
|
})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
else:
|
else:
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
@ -164,7 +174,8 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
# model = body.get('model', shared.model_name) # ignored, use existing for now
|
# model = body.get('model', shared.model_name) # ignored, use existing for now
|
||||||
model = shared.model_name
|
model = shared.model_name
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
cmpl_id = "conv-%d" % (created_time)
|
|
||||||
|
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
|
||||||
|
|
||||||
# Try to use openai defaults or map them to something with the same intent
|
# Try to use openai defaults or map them to something with the same intent
|
||||||
stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
|
stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
|
||||||
@ -181,10 +192,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
||||||
|
# if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
|
||||||
# hard scale this, assuming the given max is for GPT3/4, perhaps inspect the requested model and lookup the context max
|
|
||||||
while truncation_length <= max_tokens:
|
|
||||||
max_tokens = max_tokens // 2
|
|
||||||
|
|
||||||
req_params = {
|
req_params = {
|
||||||
'max_new_tokens': max_tokens,
|
'max_new_tokens': max_tokens,
|
||||||
@ -243,33 +251,75 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
object_type = ''
|
object_type = ''
|
||||||
|
|
||||||
if is_chat:
|
if is_chat:
|
||||||
|
# Chat Completions
|
||||||
stream_object_type = 'chat.completions.chunk'
|
stream_object_type = 'chat.completions.chunk'
|
||||||
object_type = 'chat.completions'
|
object_type = 'chat.completions'
|
||||||
|
|
||||||
messages = body['messages']
|
messages = body['messages']
|
||||||
|
|
||||||
system_msgs = []
|
role_formats = {
|
||||||
if 'prompt' in body: # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
'user': 'user: {message}\n',
|
||||||
system_msgs = [ body['prompt'] ]
|
'bot': 'assistant: {message}\n',
|
||||||
|
'system': '{message}',
|
||||||
|
'context': 'You are a helpful assistant. Answer as concisely as possible.',
|
||||||
|
'prompt': 'assistant:',
|
||||||
|
}
|
||||||
|
|
||||||
|
# Instruct models can be much better
|
||||||
|
try:
|
||||||
|
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||||
|
|
||||||
|
template = instruct['turn_template']
|
||||||
|
system_message_template = "{message}"
|
||||||
|
system_message_default = instruct['context']
|
||||||
|
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||||
|
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
||||||
|
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
||||||
|
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||||
|
|
||||||
|
role_formats = {
|
||||||
|
'user': user_message_template,
|
||||||
|
'assistant': bot_message_template,
|
||||||
|
'system': system_message_template,
|
||||||
|
'context': system_message_default,
|
||||||
|
'prompt': bot_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||||
|
except:
|
||||||
|
if debug:
|
||||||
|
print("Loaded default role format.")
|
||||||
|
|
||||||
|
system_msgs = []
|
||||||
chat_msgs = []
|
chat_msgs = []
|
||||||
|
|
||||||
|
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
||||||
|
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
|
||||||
|
if context_msg:
|
||||||
|
system_msgs.extend([context_msg])
|
||||||
|
|
||||||
|
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
||||||
|
if 'prompt' in body:
|
||||||
|
prompt_msg = role_formats['system'].format(message=body['prompt'])
|
||||||
|
system_msgs.extend([prompt_msg])
|
||||||
|
|
||||||
for m in messages:
|
for m in messages:
|
||||||
role = m['role']
|
role = m['role']
|
||||||
content = m['content']
|
content = m['content']
|
||||||
# name = m.get('name', 'user')
|
msg = role_formats[role].format(message=content)
|
||||||
if role == 'system':
|
if role == 'system':
|
||||||
system_msgs.extend([content.strip()])
|
system_msgs.extend([msg])
|
||||||
else:
|
else:
|
||||||
chat_msgs.extend([f"\n{role}: {content.strip()}"]) # Strip content? linefeed?
|
chat_msgs.extend([msg])
|
||||||
|
|
||||||
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
# can't really truncate the system messages
|
||||||
system_msg = 'You are assistant, a large language model. Answer as concisely as possible.'
|
|
||||||
if system_msgs:
|
|
||||||
system_msg = '\n'.join(system_msgs)
|
system_msg = '\n'.join(system_msgs)
|
||||||
|
if system_msg[-1] != '\n':
|
||||||
|
system_msg = system_msg + '\n'
|
||||||
|
|
||||||
system_token_count = len(encode(system_msg)[0])
|
system_token_count = len(encode(system_msg)[0])
|
||||||
remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
|
remaining_tokens = req_params['truncation_length'] - system_token_count
|
||||||
chat_msg = ''
|
chat_msg = ''
|
||||||
|
|
||||||
while chat_msgs:
|
while chat_msgs:
|
||||||
@ -279,25 +329,15 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
chat_msg = new_msg + chat_msg
|
chat_msg = new_msg + chat_msg
|
||||||
remaining_tokens -= new_size
|
remaining_tokens -= new_size
|
||||||
else:
|
else:
|
||||||
# TODO: clip a message to fit?
|
print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
|
||||||
# ie. user: ...<clipped message>
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if len(chat_msgs) > 0:
|
prompt = system_msg + chat_msg + role_formats['prompt']
|
||||||
print(f"truncating chat messages, dropping {len(chat_msgs)} messages.")
|
|
||||||
|
|
||||||
if system_msg:
|
|
||||||
prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant:'
|
|
||||||
else:
|
|
||||||
prompt = chat_msg + '\nassistant:'
|
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
# pass with some expected stop strings.
|
|
||||||
# some strange cases of "##| Instruction: " sneaking through.
|
|
||||||
stopping_strings += standard_stopping_strings
|
|
||||||
req_params['custom_stopping_strings'] = stopping_strings
|
|
||||||
else:
|
else:
|
||||||
|
# Text Completions
|
||||||
stream_object_type = 'text_completion.chunk'
|
stream_object_type = 'text_completion.chunk'
|
||||||
object_type = 'text_completion'
|
object_type = 'text_completion'
|
||||||
|
|
||||||
@ -312,9 +352,16 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
token_count = len(encode(prompt)[0])
|
||||||
if token_count >= req_params['truncation_length']:
|
if token_count >= req_params['truncation_length']:
|
||||||
new_len = int(len(prompt) * (float(shared.settings['truncation_length']) - req_params['max_new_tokens']) / token_count)
|
new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count)
|
||||||
prompt = prompt[-new_len:]
|
prompt = prompt[-new_len:]
|
||||||
print(f"truncating prompt to {new_len} characters, was {token_count} tokens. Now: {len(encode(prompt)[0])} tokens.")
|
new_token_count = len(encode(prompt)[0])
|
||||||
|
print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
|
||||||
|
token_count = new_token_count
|
||||||
|
|
||||||
|
if req_params['truncation_length'] - token_count < req_params['max_new_tokens']:
|
||||||
|
print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
|
||||||
|
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||||
|
print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
|
||||||
|
|
||||||
# pass with some expected stop strings.
|
# pass with some expected stop strings.
|
||||||
# some strange cases of "##| Instruction: " sneaking through.
|
# some strange cases of "##| Instruction: " sneaking through.
|
||||||
@ -338,11 +385,9 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if stream_object_type == 'text_completion.chunk':
|
if stream_object_type == 'text_completion.chunk':
|
||||||
chunk[resp_list][0]["text"] = ""
|
chunk[resp_list][0]["text"] = ""
|
||||||
else:
|
else:
|
||||||
# This is coming back as "system" to the openapi cli, not sure why.
|
|
||||||
# So yeah... do both methods? delta and messages.
|
# So yeah... do both methods? delta and messages.
|
||||||
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
|
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
|
||||||
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
|
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
|
||||||
# { "role": "assistant" }
|
|
||||||
|
|
||||||
response = 'data: ' + json.dumps(chunk) + '\n'
|
response = 'data: ' + json.dumps(chunk) + '\n'
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
@ -449,7 +494,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if debug:
|
if debug:
|
||||||
if answer and answer[0] == ' ':
|
if answer and answer[0] == ' ':
|
||||||
answer = answer[1:]
|
answer = answer[1:]
|
||||||
print({'response': answer})
|
print({'answer': answer}, chunk)
|
||||||
return
|
return
|
||||||
|
|
||||||
# strip extra leading space off new generated content
|
# strip extra leading space off new generated content
|
||||||
|
Loading…
Reference in New Issue
Block a user