mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-31 14:40:15 +01:00
120 lines
3.6 KiB
Python
120 lines
3.6 KiB
Python
import json
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from threading import Thread
|
|
|
|
from extensions.api.util import build_parameters, try_start_cloudflared
|
|
from modules import shared
|
|
from modules.chat import generate_chat_reply
|
|
from modules.text_generation import encode, generate_reply
|
|
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """Blocking (non-streaming) HTTP handler for the text-generation API.

    Endpoints:
        GET  /api/v1/model        -> {"result": shared.model_name}
        POST /api/v1/generate     -> {"results": [{"text": <last generated value>}]}
        POST /api/v1/chat         -> {"results": [{"history": <last generated history>}]}
        POST /api/v1/token-count  -> {"results": [{"tokens": <number of tokens>}]}

    Unknown paths get 404. A POST whose body is not a JSON object, or that is
    missing a field its route requires, gets 400.
    """

    def _send_json(self, payload, status=200):
        """Write *payload* as a complete JSON response with the given status.

        The body is serialized before any header is sent, so a serialization
        failure cannot leave the client holding a status line with no body.
        """
        data = json.dumps(payload).encode('utf-8')
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _read_json_body(self):
        """Return the request body parsed as a JSON object (a dict).

        Returns None after replying 400 when Content-Length is missing,
        malformed or negative, or when the body is not UTF-8 JSON describing
        an object.
        """
        try:
            content_length = int(self.headers.get('Content-Length', ''))
            if content_length < 0:
                # rfile.read(-1) would block until the client closes the socket.
                raise ValueError('negative Content-Length')
            body = json.loads(self.rfile.read(content_length).decode('utf-8'))
        except ValueError:
            # int() on a bad/missing length, UnicodeDecodeError and
            # json.JSONDecodeError are all ValueError subclasses.
            self.send_error(400, 'Request body must be a JSON object')
            return None
        if not isinstance(body, dict):
            self.send_error(400, 'Request body must be a JSON object')
            return None
        return body

    def _generate(self, body):
        """Run a non-chat completion for body['prompt'] and return the payload."""
        prompt = body['prompt']
        generate_params = build_parameters(body)
        stopping_strings = generate_params.pop('stopping_strings')
        generate_params['stream'] = False

        generator = generate_reply(
            prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

        # Drain the generator and keep only the last value it yields.
        answer = ''
        for answer in generator:
            pass

        return {'results': [{'text': answer}]}

    def _chat(self, body):
        """Run a chat turn for body['user_input'] / body['history'] and return the payload."""
        user_input = body['user_input']
        history = body['history']
        regenerate = body.get('regenerate', False)
        _continue = body.get('_continue', False)

        generate_params = build_parameters(body, chat=True)
        generate_params['stream'] = False

        generator = generate_chat_reply(
            user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

        # If the generator yields nothing, the incoming history is echoed back.
        answer = history
        for answer in generator:
            pass

        return {'results': [{'history': answer}]}

    def _token_count(self, body):
        """Return the payload holding the token count of body['prompt']."""
        tokens = encode(body['prompt'])[0]
        return {'results': [{'tokens': len(tokens)}]}

    # POST path -> (route method, body fields that must be present).
    _POST_ROUTES = {
        '/api/v1/generate': (_generate, ('prompt',)),
        '/api/v1/chat': (_chat, ('user_input', 'history')),
        '/api/v1/token-count': (_token_count, ('prompt',)),
    }

    def do_GET(self):
        """Serve GET /api/v1/model; every other GET path is a 404."""
        if self.path == '/api/v1/model':
            self._send_json({'result': shared.model_name})
        else:
            self.send_error(404)

    def do_POST(self):
        """Parse the JSON body, validate it, dispatch by path and reply with JSON."""
        body = self._read_json_body()
        if body is None:
            return

        route = self._POST_ROUTES.get(self.path)
        if route is None:
            self.send_error(404)
            return

        route_method, required_fields = route
        missing = [field for field in required_fields if field not in body]
        if missing:
            self.send_error(400, 'Missing required field(s): ' + ', '.join(missing))
            return

        # Compute the whole response before sending any header so a failure
        # during generation does not leave the client with an empty 200.
        payload = route_method(self, body)
        self._send_json(payload)
|
|
|
|
|
|
def _run_server(port: int, share: bool = False):
    """Bind the blocking API server on *port* and serve requests forever.

    Binds to all interfaces when ``shared.args.listen`` is truthy, otherwise
    to 127.0.0.1 only. When *share* is True, tries to open a cloudflared
    tunnel (up to 3 attempts) via ``try_start_cloudflared``; ``on_start`` is
    passed to it to print the public URL (presumably called once the tunnel
    is up — confirm in extensions.api.util). A tunnel failure is reported but
    does not stop the local server.
    """
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'

    server = ThreadingHTTPServer((address, port), Handler)

    def on_start(public_url: str):
        print(f'Starting non-streaming server at public url {public_url}/api')

    if share:
        try:
            try_start_cloudflared(port, max_attempts=3, on_start=on_start)
        except Exception as err:
            # Sharing is best-effort: keep serving locally, but explain why
            # no public URL appeared instead of failing silently.
            print(f'Could not start cloudflared tunnel ({err}); serving locally only')
            print(f'Starting API at http://{address}:{port}/api')
    else:
        print(f'Starting API at http://{address}:{port}/api')

    # The context manager closes the listening socket if serve_forever exits.
    with server:
        server.serve_forever()
|
|
|
|
|
|
def start_server(port: int, share: bool = False):
    """Start the blocking API server in a background daemon thread.

    Returns immediately. The thread runs ``_run_server(port, share)`` and,
    being a daemon, will not keep the process alive on its own.
    """
    server_thread = Thread(
        target=_run_server,
        args=[port, share],
        daemon=True,
    )
    server_thread.start()
|