Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-22 08:07:56 +01:00)
Style/pep8 improvements
This commit is contained in:
parent ecd79caa68
commit 320fcfde4e
@@ -1,8 +1,9 @@
-import json, time, os
+import json
+import os
+import time
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread

 from modules import shared
 from modules.text_generation import encode, generate_reply
@@ -25,6 +26,8 @@ embedding_model = None
 standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]

 # little helper to get defaults if arg is present but None and should be the same type as default.
+
+
 def default(dic, key, default):
     val = dic.get(key, default)
     if type(val) != type(default):
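Side note on this helper: the hunks above and below only show its first and last lines. A minimal, self-contained sketch of the behavior the comment describes follows; it is not the verbatim implementation and assumes the elided middle attempts a type coercion before falling back to the default.

# Sketch only: approximate behavior of the default() helper described above.
def default(dic, key, default):
    val = dic.get(key, default)
    if type(val) != type(default):
        try:
            v = type(default)(val)
            if type(val)(v) == val:  # value round-trips cleanly, accept the coerced value
                return v
        except Exception:
            pass
        val = default
    return val

body = {'temperature': None, 'top_p': '0.9'}
print(default(body, 'temperature', 1.0))  # -> 1.0 (present but None, replaced by the default)
print(default(body, 'top_p', 1.0))        # -> 0.9 (string coerced to float)
print(default(body, 'max_tokens', 16))    # -> 16  (missing key, default returned)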
@@ -39,6 +42,7 @@ def default(dic, key, default):
         val = default
     return val

+
 def clamp(value, minvalue, maxvalue):
     return max(minvalue, min(value, maxvalue))
@@ -103,8 +107,10 @@ class Handler(BaseHTTPRequestHandler):
         content_length = int(self.headers['Content-Length'])
         body = json.loads(self.rfile.read(content_length).decode('utf-8'))

-        if debug: print(self.headers) # did you know... python-openai sends your linux kernel & python version?
-        if debug: print(body)
+        if debug:
+            print(self.headers) # did you know... python-openai sends your linux kernel & python version?
+        if debug:
+            print(body)

         if '/completions' in self.path or '/generate' in self.path:
             is_legacy = '/generate' in self.path
@@ -143,10 +149,10 @@ class Handler(BaseHTTPRequestHandler):
                 'temperature': default(body, 'temperature', 1.0),
                 'top_p': default(body, 'top_p', 1.0),
                 'top_k': default(body, 'best_of', 1),
-                ### XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
+                # XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
                 # 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
                 'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
-                ### XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
+                # XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
                 'encoder_repetition_penalty': 1.0, # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
                 'suffix': body.get('suffix', None),
                 'stream': default(body, 'stream', False),
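On the XXX comments above: the commented-out formula (default(body, 'presence_penalty', 0) + 2.0) / 2.0 maps OpenAI's additive -2.0..2.0 range (default 0) onto the multiplicative 0.0..2.0 range used here (default 1.0). A small illustration of that mapping, for reference only since the code currently hard-codes 1.18:

# Reference only: the scaling implied by the commented-out formula; the diff
# keeps repetition_penalty = 1.18 rather than applying it.
def scale_penalty(openai_penalty):
    # additive -2.0..2.0 (OpenAI default 0) -> multiplicative 0.0..2.0 (default 1.0)
    return (openai_penalty + 2.0) / 2.0

for p in (-2.0, 0.0, 2.0):
    print(p, '->', scale_penalty(p))  # -2.0 -> 0.0, 0.0 -> 1.0, 2.0 -> 2.0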
@@ -208,7 +214,7 @@ class Handler(BaseHTTPRequestHandler):
                     if role == 'system':
                         system_msg += content
                     else:
-                        chat_msgs.extend([f"\n{role}: {content.strip()}"]) ### Strip content? linefeed?
+                        chat_msgs.extend([f"\n{role}: {content.strip()}"]) # Strip content? linefeed?

                 system_token_count = len(encode(system_msg)[0])
                 remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
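For clarity on the budget arithmetic in the last context line: the truncation length is split between the reserved reply length, the system message, and whatever chat history fits in the remainder. With made-up numbers (not values from the diff):

# Illustrative numbers only: how remaining_tokens is derived.
truncation_length = 2048   # model context window
max_new_tokens = 200       # reserved for the generated reply
system_token_count = 48    # tokens already used by the system message

remaining_tokens = truncation_length - max_new_tokens - system_token_count
print(remaining_tokens)    # 1800 tokens left over for the chat history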
@@ -263,7 +269,6 @@ class Handler(BaseHTTPRequestHandler):
             stopping_strings += standard_stopping_strings
             req_params['custom_stopping_strings'] = stopping_strings

-
             shared.args.no_stream = not req_params['stream']
             if not shared.args.no_stream:
                 shared.args.chat = True
@@ -292,7 +297,8 @@ class Handler(BaseHTTPRequestHandler):
                 self.wfile.write(response.encode('utf-8'))

             # generate reply #######################################
-            if debug: print ({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
+            if debug:
+                print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
             generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings)

             answer = ''
@@ -386,11 +392,13 @@ class Handler(BaseHTTPRequestHandler):
                     chunk[resp_list][0]['delta'] = {}
                     response = 'data: ' + json.dumps(chunk) + '\ndata: [DONE]\n'
                     self.wfile.write(response.encode('utf-8'))
-                    ###### Finished if streaming.
-                    if debug: print({'response': answer})
+                    # Finished if streaming.
+                    if debug:
+                        print({'response': answer})
                     return

-            if debug: print({'response': answer})
+            if debug:
+                print({'response': answer})

             completion_token_count = len(encode(answer)[0])
             stop_reason = "stop"
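The streaming branch above ends the event stream with a final chunk followed by data: [DONE]. A hypothetical client-side sketch of consuming that stream; the URL, port, and payload keys are assumptions based on the OpenAI completions schema this extension emulates, not values taken from the diff.

import json
import requests

# Hypothetical endpoint; host and port are assumptions.
resp = requests.post(
    'http://127.0.0.1:5001/v1/completions',
    json={'prompt': 'Hello', 'max_tokens': 16, 'stream': True},
    stream=True,
)
for raw in resp.iter_lines():
    if not raw:
        continue
    line = raw.decode('utf-8')
    if not line.startswith('data: '):
        continue
    payload = line[len('data: '):]
    if payload == '[DONE]':
        break  # matches the terminator written by the handler above
    chunk = json.loads(payload)
    # 'text' assumes the completions response shape; chat chunks carry 'delta' instead.
    print(chunk['choices'][0].get('text', ''), end='', flush=True)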
@@ -420,7 +428,7 @@ class Handler(BaseHTTPRequestHandler):

             response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))
-        elif '/embeddings' in self.path and embedding_model != None:
+        elif '/embeddings' in self.path and embedding_model is not None:
             self.send_response(200)
             self.send_header('Content-Type', 'application/json')
             self.end_headers()
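The embedding_model change above is the usual PEP 8 fix: comparisons against None should use identity, because equality operators can be overridden by a class while identity cannot. A quick illustration:

class AlwaysEqual:
    def __eq__(self, other):
        return True  # the derived __ne__ therefore always returns False

x = AlwaysEqual()
print(x != None)      # False -- the overridden comparison lies
print(x is not None)  # True  -- identity check is unaffected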
@@ -443,7 +451,8 @@ class Handler(BaseHTTPRequestHandler):
                 }
             })

-            if debug: print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
+            if debug:
+                print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
             self.wfile.write(response.encode('utf-8'))
         elif '/moderations' in self.path:
             # for now do nothing, just don't error.
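For orientation, the debug line above reports a list of fixed-size vectors. A hypothetical request against this route; the URL, port, and the data/embedding keys are assumptions based on the OpenAI embeddings schema this extension emulates.

import requests

r = requests.post(
    'http://127.0.0.1:5001/v1/embeddings',   # assumed endpoint and port
    json={'input': ['first sentence', 'second sentence']},
)
data = r.json()['data']
print(len(data), len(data[0]['embedding']))  # number of inputs, vector width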
@@ -521,4 +530,3 @@ def run_server():

 def setup():
     Thread(target=run_server, daemon=True).start()
-