Style/pep8 improvements

This commit is contained in:
oobabooga 2023-05-02 23:05:38 -03:00
parent ecd79caa68
commit 320fcfde4e

View File

@ -1,8 +1,9 @@
import json, time, os
import json
import os
import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from modules import shared
from modules.text_generation import encode, generate_reply
@ -25,6 +26,8 @@ embedding_model = None
standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
# little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default):
val = dic.get(key, default)
if type(val) != type(default):
@ -39,6 +42,7 @@ def default(dic, key, default):
val = default
return val
def clamp(value, minvalue, maxvalue):
return max(minvalue, min(value, maxvalue))
@ -103,8 +107,10 @@ class Handler(BaseHTTPRequestHandler):
content_length = int(self.headers['Content-Length'])
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
if debug: print(self.headers) # did you know... python-openai sends your linux kernel & python version?
if debug: print(body)
if debug:
print(self.headers) # did you know... python-openai sends your linux kernel & python version?
if debug:
print(body)
if '/completions' in self.path or '/generate' in self.path:
is_legacy = '/generate' in self.path
@ -112,7 +118,7 @@ class Handler(BaseHTTPRequestHandler):
resp_list = 'data' if is_legacy else 'choices'
# XXX model is ignored for now
#model = body.get('model', shared.model_name) # ignored, use existing for now
# model = body.get('model', shared.model_name) # ignored, use existing for now
model = shared.model_name
created_time = int(time.time())
cmpl_id = "conv-%d" % (created_time)
@ -143,17 +149,17 @@ class Handler(BaseHTTPRequestHandler):
'temperature': default(body, 'temperature', 1.0),
'top_p': default(body, 'top_p', 1.0),
'top_k': default(body, 'best_of', 1),
### XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
# XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
# 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
### XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
'encoder_repetition_penalty': 1.0, #(default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
# XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
'encoder_repetition_penalty': 1.0, # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
'suffix': body.get('suffix', None),
'stream': default(body, 'stream', False),
'echo': default(body, 'echo', False),
#####################################################
'seed': shared.settings.get('seed', -1),
#int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
# int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
# unofficial, but it needs to get set anyways.
'truncation_length': truncation_length,
# no more args.
@ -178,7 +184,7 @@ class Handler(BaseHTTPRequestHandler):
if req_params['stream']:
self.send_header('Content-Type', 'text/event-stream')
self.send_header('Cache-Control', 'no-cache')
#self.send_header('Connection', 'keep-alive')
# self.send_header('Connection', 'keep-alive')
else:
self.send_header('Content-Type', 'application/json')
self.end_headers()
@ -204,11 +210,11 @@ class Handler(BaseHTTPRequestHandler):
for m in messages:
role = m['role']
content = m['content']
#name = m.get('name', 'user')
# name = m.get('name', 'user')
if role == 'system':
system_msg += content
else:
chat_msgs.extend([f"\n{role}: {content.strip()}"]) ### Strip content? linefeed?
chat_msgs.extend([f"\n{role}: {content.strip()}"]) # Strip content? linefeed?
system_token_count = len(encode(system_msg)[0])
remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
@ -263,7 +269,6 @@ class Handler(BaseHTTPRequestHandler):
stopping_strings += standard_stopping_strings
req_params['custom_stopping_strings'] = stopping_strings
shared.args.no_stream = not req_params['stream']
if not shared.args.no_stream:
shared.args.chat = True
@ -286,18 +291,19 @@ class Handler(BaseHTTPRequestHandler):
# So yeah... do both methods? delta and messages.
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
#{ "role": "assistant" }
# { "role": "assistant" }
response = 'data: ' + json.dumps(chunk) + '\n'
self.wfile.write(response.encode('utf-8'))
# generate reply #######################################
if debug: print ({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
if debug:
print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings)
answer = ''
seen_content = ''
longest_stop_len = max([ len(x) for x in stopping_strings ])
longest_stop_len = max([len(x) for x in stopping_strings])
for a in generator:
if isinstance(a, str):
@ -356,8 +362,8 @@ class Handler(BaseHTTPRequestHandler):
chunk[resp_list][0]['text'] = new_content
else:
# So yeah... do both methods? delta and messages.
chunk[resp_list][0]['message'] = { 'content': new_content }
chunk[resp_list][0]['delta'] = { 'content': new_content }
chunk[resp_list][0]['message'] = {'content': new_content}
chunk[resp_list][0]['delta'] = {'content': new_content}
response = 'data: ' + json.dumps(chunk) + '\n'
self.wfile.write(response.encode('utf-8'))
completion_token_count += len(encode(new_content)[0])
@ -382,15 +388,17 @@ class Handler(BaseHTTPRequestHandler):
chunk[resp_list][0]['text'] = ''
else:
# So yeah... do both methods? delta and messages.
chunk[resp_list][0]['message'] = {'content': '' }
chunk[resp_list][0]['message'] = {'content': ''}
chunk[resp_list][0]['delta'] = {}
response = 'data: ' + json.dumps(chunk) + '\ndata: [DONE]\n'
self.wfile.write(response.encode('utf-8'))
###### Finished if streaming.
if debug: print({'response': answer})
# Finished if streaming.
if debug:
print({'response': answer})
return
if debug: print({'response': answer})
if debug:
print({'response': answer})
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
@ -414,13 +422,13 @@ class Handler(BaseHTTPRequestHandler):
}
if is_chat:
resp[resp_list][0]["message"] = {"role": "assistant", "content": answer }
resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
else:
resp[resp_list][0]["text"] = answer
response = json.dumps(resp)
self.wfile.write(response.encode('utf-8'))
elif '/embeddings' in self.path and embedding_model != None:
elif '/embeddings' in self.path and embedding_model is not None:
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
@ -431,7 +439,7 @@ class Handler(BaseHTTPRequestHandler):
embeddings = embedding_model.encode(input).tolist()
data = [ {"object": "embedding", "embedding": emb, "index": n } for n, emb in enumerate(embeddings) ]
data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
response = json.dumps({
"object": "list",
@ -443,7 +451,8 @@ class Handler(BaseHTTPRequestHandler):
}
})
if debug: print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
if debug:
print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
self.wfile.write(response.encode('utf-8'))
elif '/moderations' in self.path:
# for now do nothing, just don't error.
@ -521,4 +530,3 @@ def run_server():
def setup():
Thread(target=run_server, daemon=True).start()