Andy Salerno 654933c634
New universal API with streaming/blocking endpoints (#990)
Previous title: Add api_streaming extension and update api-example-stream to use it

* Merge with latest main

* Add parameter capturing encoder_repetition_penalty

* Change some defaults, minor fixes

* Add --api, --public-api flags

* Remove an unneeded/broken comment from the blocking API startup. The comment is already emitted correctly by try_start_cloudflared, which calls the lambda we pass in.

* Update the on_start message for blocking_api: it should say 'non-streaming', not 'streaming'

* Update the API examples

* Change a comment

* Update README

* Remove the gradio API

* Remove unused import

* Minor change

* Remove unused import

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
2023-04-23 15:52:43 -03:00

91 lines
2.8 KiB
Python

import json
import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

from modules import shared
from modules.text_generation import encode, generate_reply
from extensions.api.util import build_parameters, try_start_cloudflared
class Handler(BaseHTTPRequestHandler):
    """Request handler for the blocking (non-streaming) JSON API.

    Endpoints:
        GET  /api/v1/model        -> {"result": <shared.model_name>}
        POST /api/v1/generate     -> {"results": [{"text": <generated text>}]}
        POST /api/v1/token-count  -> {"results": [{"tokens": <prompt token count>}]}

    POST bodies must be a JSON object with a string "prompt" field. Malformed
    requests get a 400, unknown paths a 404, and errors raised while producing
    the result a 500. The status line is only written once the response payload
    has been fully computed, so a failure can never follow a 200.
    """

    def do_GET(self):
        if self.path != '/api/v1/model':
            self.send_error(404)
            return
        self._send_json({'result': shared.model_name})

    def do_POST(self):
        # Consume the announced body before routing, even for unknown paths,
        # so no unread request bytes are left on the connection.
        try:
            raw = self._read_body(self.headers, self.rfile)
        except ValueError as err:
            self.send_error(400, explain=str(err))
            return

        routes = {
            '/api/v1/generate': self._generate,
            '/api/v1/token-count': self._token_count,
        }
        action = routes.get(self.path)
        if action is None:
            self.send_error(404)
            return

        try:
            body = self._decode_json_object(raw)
        except ValueError as err:
            # Covers JSONDecodeError and UnicodeDecodeError (both ValueError subclasses).
            self.send_error(400, explain=f'invalid JSON body: {err}')
            return
        if not isinstance(body.get('prompt'), str):
            self.send_error(400, explain="'prompt' must be a string")
            return

        try:
            payload = action(body)
        except Exception:
            # Request boundary: keep the traceback in the server log and answer
            # with a 500 instead of dropping the connection without a response.
            self.log_error('Error while handling %s:\n%s', self.path, traceback.format_exc())
            self.send_error(500)
            return
        self._send_json(payload)

    def _generate(self, body):
        """Run text generation for ``body['prompt']`` and return the response payload."""
        prompt = body['prompt']
        generate_params = build_parameters(body)
        stopping_strings = generate_params.pop('stopping_strings')
        answer = ''
        for chunk in generate_reply(prompt, generate_params, stopping_strings=stopping_strings):
            # Each yielded item is either the text itself or a sequence whose
            # first element is the text; only the last one is kept.
            answer = chunk if isinstance(chunk, str) else chunk[0]
        # Outside chat mode the reply is assumed to start with the prompt,
        # which is stripped so only the continuation is returned.
        text = answer if shared.is_chat() else answer[len(prompt):]
        return {'results': [{'text': text}]}

    def _token_count(self, body):
        """Return the payload with the number of tokens in ``body['prompt']``."""
        tokens = encode(body['prompt'])[0]
        return {'results': [{'tokens': len(tokens)}]}

    def _send_json(self, payload, status=200):
        """Serialize *payload* and send it as a complete application/json response."""
        data = json.dumps(payload).encode('utf-8')
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    @staticmethod
    def _read_body(headers, rfile):
        """Return the raw request body announced by the Content-Length header.

        A missing header is treated as an empty body.

        Raises:
            ValueError: if Content-Length is not a non-negative integer.
        """
        raw_length = headers.get('Content-Length')
        if raw_length is None:
            return b''
        try:
            length = int(raw_length)
        except ValueError:
            raise ValueError(f'invalid Content-Length: {raw_length!r}') from None
        if length < 0:
            raise ValueError(f'invalid Content-Length: {raw_length!r}')
        return rfile.read(length)

    @staticmethod
    def _decode_json_object(raw):
        """Decode UTF-8 *raw* bytes as a JSON object (dict).

        Raises:
            ValueError: if the bytes are not UTF-8, not JSON, or not a JSON object.
        """
        body = json.loads(raw.decode('utf-8'))
        if not isinstance(body, dict):
            raise ValueError('request body must be a JSON object')
        return body
def _run_server(port: int, share: bool = False):
    """Serve the blocking API on *port* until the process exits.

    Binds to all interfaces when ``shared.args.listen`` is set, otherwise to
    localhost only. When *share* is true, a cloudflared tunnel is requested via
    ``try_start_cloudflared`` (with ``max_attempts=3``), passing ``on_start`` to
    print the public URL. Tunnel setup is best-effort: if it raises, the
    failure is reported and the local URL is announced instead.
    """
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    server = ThreadingHTTPServer((address, port), Handler)

    def on_start(public_url: str):
        print(f'Starting non-streaming server at public url {public_url}/api')

    tunnel_started = False
    if share:
        try:
            try_start_cloudflared(port, max_attempts=3, on_start=on_start)
            tunnel_started = True
        except Exception as exc:
            # A tunnel failure must not take the local API down, but it must
            # not be silent either: the user would get no URL at all.
            print(f'Could not start the public API tunnel: {exc!r}')

    if not tunnel_started:
        print(f'Starting API at http://{address}:{port}/api')

    try:
        server.serve_forever()
    finally:
        # Release the listening socket if serve_forever ever returns or raises.
        server.server_close()
def start_server(port: int, share: bool = False):
    """Launch the blocking API server on a background daemon thread.

    Returns immediately; the server runs in ``_run_server`` with the given
    *port* and *share* flag, and the daemon thread does not keep the process alive.
    """
    server_thread = Thread(target=_run_server, args=(port, share), daemon=True)
    server_thread.start()