Style/pep8 improvements

oobabooga 2023-05-02 23:05:38 -03:00
parent ecd79caa68
commit 320fcfde4e


@@ -1,8 +1,9 @@
-import json, time, os
+import json
+import os
+import time
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 
 from modules import shared
 from modules.text_generation import encode, generate_reply
@@ -25,6 +26,8 @@ embedding_model = None
 standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
+
+
 # little helper to get defaults if arg is present but None and should be the same type as default.
 def default(dic, key, default):
     val = dic.get(key, default)
     if type(val) != type(default):
@@ -39,6 +42,7 @@ def default(dic, key, default):
         val = default
     return val
 
+
 def clamp(value, minvalue, maxvalue):
     return max(minvalue, min(value, maxvalue))
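For reference, here is a rough sketch of how these two helpers behave, with illustrative values that are not part of this commit. The body of `default` between the type check and `val = default` falls outside the hunk's context lines, so the coercion step shown here is one plausible reading of it:

```python
# Illustrative only: simplified restatement of the default()/clamp() helpers.
def default(dic, key, default):
    val = dic.get(key, default)
    if type(val) != type(default):
        # Plausible reading of the elided middle of the function: try to
        # coerce the value to the default's type, falling back otherwise.
        try:
            v = type(default)(val)
            if type(val)(v) == val:
                return v
        except Exception:
            pass
        val = default
    return val


def clamp(value, minvalue, maxvalue):
    return max(minvalue, min(value, maxvalue))


body = {'temperature': None, 'top_p': '0.9', 'max_tokens': 16}
print(default(body, 'temperature', 1.0))  # 1.0  (None fails the type check)
print(default(body, 'top_p', 1.0))        # 0.9  ('0.9' coerces cleanly to float)
print(default(body, 'max_tokens', 200))   # 16   (already the right type)
print(clamp(3.5, 0.0, 2.0))               # 2.0  (bounded to the range)
```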
@@ -103,8 +107,10 @@ class Handler(BaseHTTPRequestHandler):
         content_length = int(self.headers['Content-Length'])
         body = json.loads(self.rfile.read(content_length).decode('utf-8'))
 
-        if debug: print(self.headers)  # did you know... python-openai sends your linux kernel & python version?
-        if debug: print(body)
+        if debug:
+            print(self.headers)  # did you know... python-openai sends your linux kernel & python version?
+        if debug:
+            print(body)
 
         if '/completions' in self.path or '/generate' in self.path:
             is_legacy = '/generate' in self.path
@@ -143,10 +149,10 @@ class Handler(BaseHTTPRequestHandler):
                 'temperature': default(body, 'temperature', 1.0),
                 'top_p': default(body, 'top_p', 1.0),
                 'top_k': default(body, 'best_of', 1),
-                ### XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
+                # XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
                 # 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
                 'repetition_penalty': 1.18,  # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
-                ### XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
+                # XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
                 'encoder_repetition_penalty': 1.0,  # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
                 'suffix': body.get('suffix', None),
                 'stream': default(body, 'stream', False),
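The XXX comments above wrestle with mapping OpenAI's presence/frequency penalties (range -2.0 to 2.0, neutral at 0) onto the webui's repetition penalties (neutral at 1.0). A minimal sketch of the scaling the commented-out code proposes, which this commit keeps disabled in favor of the fixed 1.18 and 1.0:

```python
def scale_penalty(openai_penalty: float) -> float:
    # (p + 2.0) / 2.0 maps OpenAI's -2.0..2.0 (neutral 0) onto 0.0..2.0
    # (neutral 1.0), which is the mapping the commented-out code suggests.
    return (openai_penalty + 2.0) / 2.0


assert scale_penalty(0.0) == 1.0    # OpenAI default -> neutral repetition penalty
assert scale_penalty(2.0) == 2.0    # strongest discouragement of repetition
assert scale_penalty(-2.0) == 0.0   # strongest encouragement of repetition
```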
@@ -208,7 +214,7 @@ class Handler(BaseHTTPRequestHandler):
                 if role == 'system':
                     system_msg += content
                 else:
-                    chat_msgs.extend([f"\n{role}: {content.strip()}"])  ### Strip content? linefeed?
+                    chat_msgs.extend([f"\n{role}: {content.strip()}"])  # Strip content? linefeed?
 
             system_token_count = len(encode(system_msg)[0])
             remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
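The token budgeting in this hunk reserves room for the reply and the system message before fitting chat history. A sketch with illustrative numbers; in the real code these come from `req_params` and `encode()`:

```python
# Illustrative numbers only.
truncation_length = 2048  # model context window
max_new_tokens = 200      # reserved for the generated reply
system_token_count = 35   # tokens consumed by the system message

# Whatever is left is the budget for chat history, presumably used to decide
# how many of the accumulated chat_msgs can be kept in the prompt.
remaining_tokens = truncation_length - max_new_tokens - system_token_count
print(remaining_tokens)  # 1813
```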
@@ -263,7 +269,6 @@ class Handler(BaseHTTPRequestHandler):
             stopping_strings += standard_stopping_strings
             req_params['custom_stopping_strings'] = stopping_strings
-
             shared.args.no_stream = not req_params['stream']
             if not shared.args.no_stream:
                 shared.args.chat = True
@@ -292,7 +297,8 @@ class Handler(BaseHTTPRequestHandler):
                 self.wfile.write(response.encode('utf-8'))
 
             # generate reply #######################################
-            if debug: print ({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
+            if debug:
+                print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
             generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings)
 
             answer = ''
@@ -386,11 +392,13 @@ class Handler(BaseHTTPRequestHandler):
                     chunk[resp_list][0]['delta'] = {}
                     response = 'data: ' + json.dumps(chunk) + '\ndata: [DONE]\n'
                     self.wfile.write(response.encode('utf-8'))
-                    ###### Finished if streaming.
-                    if debug: print({'response': answer})
+                    # Finished if streaming.
+                    if debug:
+                        print({'response': answer})
                     return
 
-                if debug: print({'response': answer})
+                if debug:
+                    print({'response': answer})
 
                 completion_token_count = len(encode(answer)[0])
                 stop_reason = "stop"
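When streaming, the handler closes the event stream with an empty delta chunk followed by `data: [DONE]`, matching OpenAI's server-sent-events framing. A rough sketch of a client consuming that framing (hypothetical helper, assuming one `data:` payload per line):

```python
import json


def iter_sse_chunks(lines):
    # Decode each 'data:' payload until the '[DONE]' sentinel the handler emits.
    for line in lines:
        if not line.startswith('data: '):
            continue
        payload = line[len('data: '):].strip()
        if payload == '[DONE]':
            break
        yield json.loads(payload)


stream = [
    'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    'data: {"choices": [{"delta": {}}]}',  # final, empty delta
    'data: [DONE]',
]
for chunk in iter_sse_chunks(stream):
    print(chunk['choices'][0]['delta'].get('content', ''))
```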
@@ -420,7 +428,7 @@ class Handler(BaseHTTPRequestHandler):
             response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))
 
-        elif '/embeddings' in self.path and embedding_model != None:
+        elif '/embeddings' in self.path and embedding_model is not None:
             self.send_response(200)
             self.send_header('Content-Type', 'application/json')
             self.end_headers()
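The `!= None` to `is not None` change is PEP 8's E711 fix. It is more than cosmetic: `is not None` is an identity test and cannot be hijacked by a custom `__eq__`, as this hypothetical class (not from the codebase) shows:

```python
class AlwaysEqual:
    # Pathological __eq__ that claims equality with everything, even None.
    def __eq__(self, other):
        return True


obj = AlwaysEqual()
print(obj != None)      # False: the derived __ne__ defers to __eq__
print(obj is not None)  # True: identity checks cannot be overridden
```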
@@ -443,7 +451,8 @@ class Handler(BaseHTTPRequestHandler):
                 }
             })
-            if debug: print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
+            if debug:
+                print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
             self.wfile.write(response.encode('utf-8'))
 
         elif '/moderations' in self.path:
             # for now do nothing, just don't error.
@@ -521,4 +530,3 @@ def run_server():
 
 def setup():
     Thread(target=run_server, daemon=True).start()
-