diff --git a/extensions/api/script.py b/extensions/api/script.py new file mode 100644 index 00000000..3dbf6368 --- /dev/null +++ b/extensions/api/script.py @@ -0,0 +1,82 @@ +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from threading import Thread +from modules import shared +from modules.text_generation import generate_reply, encode +import json + +params = { + 'port': 5000, +} + +class Handler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == '/api/v1/model': + self.send_response(200) + self.end_headers() + response = json.dumps({ + 'result': shared.model_name + }) + + self.wfile.write(response.encode('utf-8')) + else: + self.send_error(404) + + def do_POST(self): + content_length = int(self.headers['Content-Length']) + body = json.loads(self.rfile.read(content_length).decode('utf-8')) + + if self.path == '/api/v1/generate': + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + + prompt = body['prompt'] + prompt_lines = [l.strip() for l in prompt.split('\n')] + + max_context = body.get('max_context_length', 2048) + + while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context: + prompt_lines.pop(0) + + prompt = '\n'.join(prompt_lines) + + generator = generate_reply( + question = prompt, + max_new_tokens = body.get('max_length', 200), + do_sample=True, + temperature=body.get('temperature', 0.5), + top_p=body.get('top_p', 1), + typical_p=body.get('typical', 1), + repetition_penalty=body.get('rep_pen', 1.1), + encoder_repetition_penalty=1, + top_k=body.get('top_k', 0), + min_length=0, + no_repeat_ngram_size=0, + num_beams=1, + penalty_alpha=0, + length_penalty=1, + early_stopping=False, + ) + + answer = '' + for a in generator: + answer = a[0] + + response = json.dumps({ + 'results': [{ + 'text': answer[len(prompt):] + }] + }) + self.wfile.write(response.encode('utf-8')) + else: + self.send_error(404) + + +def run_server(): + server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port']) + server = ThreadingHTTPServer(server_addr, Handler) + print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api') + server.serve_forever() + +def ui(): + Thread(target=run_server, daemon=True).start() \ No newline at end of file