mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-30 03:18:57 +01:00
ed66ca3cdf
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
233 lines
7.4 KiB
Python
233 lines
7.4 KiB
Python
import json
|
|
import ssl
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from threading import Thread
|
|
|
|
from extensions.api.util import build_parameters, try_start_cloudflared
|
|
from modules import shared
|
|
from modules.chat import generate_chat_reply
|
|
from modules.LoRA import add_lora_to_model
|
|
from modules.models import load_model, unload_model
|
|
from modules.models_settings import get_model_metadata, update_model_parameters
|
|
from modules.text_generation import (
|
|
encode,
|
|
generate_reply,
|
|
stop_everything_event
|
|
)
|
|
from modules.utils import get_available_models
|
|
from modules.logging_colors import logger
|
|
|
|
|
|
def get_model_info():
    """Return a snapshot of the loaded model plus a dump of the runtime config.

    Returns:
        dict: ``model_name`` and ``lora_names`` taken from ``shared``, followed by
        ``shared.settings`` (the settings dict itself) and ``shared.args`` (the
        attribute dict of the parsed arguments, via ``vars``).
    """
    info = {'model_name': shared.model_name, 'lora_names': shared.lora_names}

    # Full dumps of the live configuration objects (not copies).
    info['shared.settings'] = shared.settings
    info['shared.args'] = vars(shared.args)
    return info
|
|
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """Request handler for the blocking (non-streaming) ``/api/v1`` endpoints.

    Successful endpoints reply with a JSON body. Unknown paths get a 404 and
    malformed requests (bad ``Content-Length``, non-object JSON, missing
    required fields) get a 400, both via ``send_error``. Every response gets
    the CORS / no-cache headers added in ``end_headers``.
    """

    def do_GET(self):
        """Serve ``GET /api/v1/model`` with the name of the current model."""
        if self.path == '/api/v1/model':
            # Same JSON content type as the POST endpoints (previously missing here).
            self._send_json_headers()
            self._write_json({
                'result': shared.model_name
            })
        else:
            self.send_error(404)

    def do_POST(self):
        """Dispatch a POST request to the handler registered for ``self.path``."""
        routes = {
            '/api/v1/generate': self._handle_generate,
            '/api/v1/chat': self._handle_chat,
            '/api/v1/stop-stream': self._handle_stop_stream,
            '/api/v1/model': self._handle_model,
            '/api/v1/token-count': self._handle_token_count,
        }
        handler = routes.get(self.path)
        if handler is None:
            self.send_error(404)
            return

        body = self._read_json_body()
        if body is None:
            self.send_error(400, explain='Request body must be a UTF-8 JSON object.')
            return

        handler(body)

    def _read_json_body(self):
        """Read the request body and decode it as a JSON object.

        A missing or zero ``Content-Length`` is treated as an empty object, so
        body-less requests (e.g. stop-stream) are accepted.

        Returns:
            dict | None: the decoded object, or ``None`` when the length header
            is not a non-negative integer or the body is not a JSON object.
        """
        try:
            content_length = int(self.headers.get('Content-Length') or 0)
        except ValueError:
            return None
        if content_length < 0:
            return None
        if content_length == 0:
            return {}

        raw = self.rfile.read(content_length)
        try:
            body = json.loads(raw.decode('utf-8'))
        except ValueError:  # covers both UnicodeDecodeError and JSONDecodeError
            return None
        return body if isinstance(body, dict) else None

    def _send_json_headers(self):
        """Send a 200 status and a JSON ``Content-Type`` (CORS added by end_headers)."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()

    def _write_json(self, payload):
        """Serialize ``payload`` with ``json.dumps`` and write it as UTF-8."""
        self.wfile.write(json.dumps(payload).encode('utf-8'))

    def _handle_generate(self, body):
        """``/api/v1/generate``: run ``generate_reply`` on ``body['prompt']``."""
        if 'prompt' not in body:
            self.send_error(400, explain="Missing required field 'prompt'.")
            return

        self._send_json_headers()

        prompt = body['prompt']
        generate_params = build_parameters(body)
        stopping_strings = generate_params.pop('stopping_strings')
        generate_params['stream'] = False

        generator = generate_reply(
            prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

        # Drain the generator and keep only the last value it yields.
        answer = ''
        for a in generator:
            answer = a

        self._write_json({
            'results': [{
                'text': answer
            }]
        })

    def _handle_chat(self, body):
        """``/api/v1/chat``: run ``generate_chat_reply`` and return the history."""
        if 'user_input' not in body:
            self.send_error(400, explain="Missing required field 'user_input'.")
            return

        self._send_json_headers()

        user_input = body['user_input']
        regenerate = body.get('regenerate', False)
        _continue = body.get('_continue', False)

        generate_params = build_parameters(body, chat=True)
        generate_params['stream'] = False

        generator = generate_chat_reply(
            user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

        # Fall back to the incoming history if the generator yields nothing.
        answer = generate_params['history']
        for a in generator:
            answer = a

        self._write_json({
            'results': [{
                'history': answer
            }]
        })

    def _handle_stop_stream(self, _body):
        """``/api/v1/stop-stream``: call ``stop_everything_event``; body is ignored."""
        self._send_json_headers()
        stop_everything_event()
        self._write_json({
            'results': 'success'
        })

    def _handle_model(self, body):
        """``/api/v1/model``: model management selected by ``body['action']``.

        Actions: ``load`` (requires ``model_name``; optional ``args`` dict),
        ``unload``, ``list``, ``info``. Any other action, including none,
        returns the current model name like ``GET /api/v1/model``.
        """
        action = body.get('action', '')
        if action == 'load' and 'model_name' not in body:
            self.send_error(400, explain="Missing required field 'model_name'.")
            return

        self._send_json_headers()

        # By default return the same as the GET interface.
        result = shared.model_name

        if action == 'load':
            model_name = body['model_name']
            args = body.get('args', {})
            logger.info(f'Loading model {model_name} with args {args}')
            # NOTE(review): any API client can overwrite arbitrary attributes of
            # shared.args here; only expose this API to trusted clients.
            for k in args:
                setattr(shared.args, k, args[k])

            shared.model_name = model_name
            unload_model()

            model_settings = get_model_metadata(shared.model_name)
            shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
            update_model_parameters(model_settings, initial=True)

            if shared.settings['mode'] != 'instruct':
                shared.settings['instruction_template'] = None

            try:
                shared.model, shared.tokenizer = load_model(shared.model_name)
                if shared.args.lora:
                    add_lora_to_model(shared.args.lora)  # list
            except Exception as e:
                # A 200 status was already sent, so report the failure in the body,
                # then re-raise (bare raise keeps the original traceback intact).
                self._write_json({'error': {'message': repr(e)}})
                raise

            shared.args.model = shared.model_name
            result = get_model_info()

        elif action == 'unload':
            unload_model()
            shared.model_name = None
            shared.args.model = None
            result = get_model_info()

        elif action == 'list':
            result = get_available_models()

        elif action == 'info':
            result = get_model_info()

        self._write_json({
            'result': result,
        })

    def _handle_token_count(self, body):
        """``/api/v1/token-count``: number of tokens ``encode`` yields for the prompt."""
        if 'prompt' not in body:
            self.send_error(400, explain="Missing required field 'prompt'.")
            return

        self._send_json_headers()

        tokens = encode(body['prompt'])[0]
        self._write_json({
            'results': [{
                'tokens': len(tokens)
            }]
        })

    def do_OPTIONS(self):
        """Answer CORS preflight requests; the CORS headers come from end_headers."""
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        """Append permissive CORS and no-cache headers to every response."""
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
        super().end_headers()
|
|
|
|
|
|
def _run_server(port: int, share: bool = False, tunnel_id=None):
    """Build the blocking API HTTP server and serve requests with ``Handler``.

    Binds to 0.0.0.0 when ``shared.args.listen`` is set, otherwise to
    127.0.0.1. The listening socket is wrapped in TLS when both
    ``shared.args.ssl_certfile`` and ``shared.args.ssl_keyfile`` are set.
    This call blocks in ``serve_forever``.

    Args:
        port: TCP port to listen on.
        share: when True, also call ``try_start_cloudflared`` (best effort);
            its public URL is logged through the ``on_start`` callback.
        tunnel_id: forwarded unchanged to ``try_start_cloudflared``. The default
            is now ``None``; it used to be the ``str`` type object itself,
            which is never a meaningful id.
    """
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    server = ThreadingHTTPServer((address, port), Handler)

    ssl_certfile = shared.args.ssl_certfile
    ssl_keyfile = shared.args.ssl_keyfile
    use_tls = bool(ssl_keyfile and ssl_certfile)
    if use_tls:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ssl_certfile, ssl_keyfile)
        server.socket = context.wrap_socket(server.socket, server_side=True)
    elif ssl_keyfile or ssl_certfile:
        # Only one of the pair was given: make the silent HTTP fallback visible.
        logger.warning('Both ssl_certfile and ssl_keyfile are required for HTTPS; starting the API without TLS.')

    def on_start(public_url: str):
        logger.info(f'Starting non-streaming server at public url {public_url}/api')

    if share:
        try:
            try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception:
            # Sharing is best effort: keep serving locally, but do not hide the failure.
            logger.exception('Could not start cloudflared for the API')
    else:
        scheme = 'https' if use_tls else 'http'
        logger.info(f'Starting API at {scheme}://{address}:{port}/api')

    server.serve_forever()
|
|
|
|
|
|
def start_server(port: int, share: bool = False, tunnel_id=None):
    """Run ``_run_server(port, share, tunnel_id)`` on a background daemon thread.

    Returns immediately after starting the thread. Because the thread is a
    daemon, it does not keep the process alive on its own.

    Args:
        port: TCP port for the API server.
        share: forwarded to ``_run_server``.
        tunnel_id: forwarded to ``_run_server``. The default is now ``None``; it
            used to be the ``str`` type object itself, which is never a
            meaningful tunnel id.
    """
    Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()
|