diff --git a/README.md b/README.md index 681180ba..bea64666 100644 --- a/README.md +++ b/README.md @@ -269,6 +269,13 @@ Optionally, you can use the following command-line flags: | `--auto-launch` | Open the web UI in the default browser upon launch. | | `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" | +#### API + +| Flag | Description | +|---------------------------------------|-------------| +| `--api` | Enable the API extension. | +| `--public-api` | Create a public URL for the API using Cloudflare. | + Out of memory errors? [Check the low VRAM guide](docs/Low-VRAM-guide.md). ## Presets diff --git a/api-example-stream.py b/api-example-stream.py index b8e7cfb5..b299616f 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -1,39 +1,30 @@ -''' - -Contributed by SagsMug. Thank you SagsMug. -https://github.com/oobabooga/text-generation-webui/pull/175 - -''' - import asyncio import json -import random -import string +import sys -import websockets +try: + import websockets +except ImportError: + print("Websockets package not found. Make sure it's installed.") -# Gradio changes this index from time to time. 
To rediscover it, set VISIBLE = False in -# modules/api.py and use the dev tools to inspect the request made after clicking on the -# button called "Run" at the bottom of the UI -GRADIO_FN = 34 - - -def random_hash(): - letters = string.ascii_lowercase + string.digits - return ''.join(random.choice(letters) for i in range(9)) +# For local streaming, the websockets are hosted without ssl - ws:// +HOST = 'localhost:5005' +URI = f'ws://{HOST}/api/v1/stream' +# For reverse-proxied streaming, the remote will likely host with ssl - wss:// +# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' async def run(context): - server = "127.0.0.1" - params = { - 'max_new_tokens': 200, + # Note: the selected defaults change from time to time. + request = { + 'prompt': context, + 'max_new_tokens': 250, 'do_sample': True, - 'temperature': 0.72, - 'top_p': 0.73, + 'temperature': 1.3, + 'top_p': 0.1, 'typical_p': 1, - 'repetition_penalty': 1.1, - 'encoder_repetition_penalty': 1.0, - 'top_k': 0, + 'repetition_penalty': 1.18, + 'top_k': 40, 'min_length': 0, 'no_repeat_ngram_size': 0, 'num_beams': 1, @@ -45,48 +36,31 @@ async def run(context): 'truncation_length': 2048, 'ban_eos_token': False, 'skip_special_tokens': True, - 'stopping_strings': [], + 'stopping_strings': [] } - payload = json.dumps([context, params]) - session = random_hash() - async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: - while content := json.loads(await websocket.recv()): - # Python3.10 syntax, replace with if elif on older - match content["msg"]: - case "send_hash": - await websocket.send(json.dumps({ - "session_hash": session, - "fn_index": GRADIO_FN - })) - case "estimation": - pass - case "send_data": - await websocket.send(json.dumps({ - "session_hash": session, - "fn_index": GRADIO_FN, - "data": [ - payload - ] - })) - case "process_starts": - pass - case "process_generating" | "process_completed": - yield content["output"]["data"][0] - # You can search for your desired 
end indicator and - # stop generation by closing the websocket here - if (content["msg"] == "process_completed"): - break + async with websockets.connect(URI) as websocket: + await websocket.send(json.dumps(request)) -prompt = "What I would like to say is the following: " + yield context # Remove this if you just want to see the reply + + while True: + incoming_data = await websocket.recv() + incoming_data = json.loads(incoming_data) + + match incoming_data['event']: + case 'text_stream': + yield incoming_data['text'] + case 'stream_end': + return -async def get_result(): +async def print_response_stream(prompt): async for response in run(prompt): - # Print intermediate steps - print(response) + print(response, end='') + sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. - # Print final result - print(response) -asyncio.run(get_result()) +if __name__ == '__main__': + prompt = "In order to make homemade bread, follow these steps:\n1)" + asyncio.run(print_response_stream(prompt)) diff --git a/api-example.py b/api-example.py index eff610c1..4bf4f0d6 100644 --- a/api-example.py +++ b/api-example.py @@ -1,57 +1,42 @@ -''' - -This is an example on how to use the API for oobabooga/text-generation-webui. - -Make sure to start the web UI with the following flags: - -python server.py --model MODEL --listen --no-stream - -Optionally, you can also add the --share flag to generate a public gradio URL, -allowing you to use the API remotely. 
- -''' -import json - import requests -# Server address -server = "127.0.0.1" +# For local streaming, the websockets are hosted without ssl - http:// +HOST = 'localhost:5000' +URI = f'http://{HOST}/api/v1/generate' -# Generation parameters -# Reference: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig -params = { - 'max_new_tokens': 200, - 'do_sample': True, - 'temperature': 0.72, - 'top_p': 0.73, - 'typical_p': 1, - 'repetition_penalty': 1.1, - 'encoder_repetition_penalty': 1.0, - 'top_k': 0, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - 'seed': -1, - 'add_bos_token': True, - 'truncation_length': 2048, - 'ban_eos_token': False, - 'skip_special_tokens': True, - 'stopping_strings': [], -} +# For reverse-proxied streaming, the remote will likely host with ssl - https:// +# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate' -# Input prompt -prompt = "What I would like to say is the following: " +def run(context): + request = { + 'prompt': prompt, + 'max_new_tokens': 250, + 'do_sample': True, + 'temperature': 1.3, + 'top_p': 0.1, + 'typical_p': 1, + 'repetition_penalty': 1.18, + 'top_k': 40, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'length_penalty': 1, + 'early_stopping': False, + 'seed': -1, + 'add_bos_token': True, + 'truncation_length': 2048, + 'ban_eos_token': False, + 'skip_special_tokens': True, + 'stopping_strings': [] + } -payload = json.dumps([prompt, params]) + response = requests.post(URI, json=request) -response = requests.post(f"http://{server}:7860/run/textgen", json={ - "data": [ - payload - ] -}).json() + if response.status_code == 200: + result = response.json()['results'][0]['text'] + print(prompt + result) -reply = response["data"][0] -print(reply) +if __name__ == '__main__': + prompt = "In order to make homemade bread, follow these steps:\n1)" + 
run(prompt) diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py new file mode 100644 index 00000000..e66a6a50 --- /dev/null +++ b/extensions/api/blocking_api.py @@ -0,0 +1,90 @@ +import json +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from threading import Thread + +from modules import shared +from modules.text_generation import encode, generate_reply + +from extensions.api.util import build_parameters, try_start_cloudflared + + +class Handler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == '/api/v1/model': + self.send_response(200) + self.end_headers() + response = json.dumps({ + 'result': shared.model_name + }) + + self.wfile.write(response.encode('utf-8')) + else: + self.send_error(404) + + def do_POST(self): + content_length = int(self.headers['Content-Length']) + body = json.loads(self.rfile.read(content_length).decode('utf-8')) + + if self.path == '/api/v1/generate': + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + + prompt = body['prompt'] + generate_params = build_parameters(body) + stopping_strings = generate_params.pop('stopping_strings') + + generator = generate_reply( + prompt, generate_params, stopping_strings=stopping_strings) + + answer = '' + for a in generator: + if isinstance(a, str): + answer = a + else: + answer = a[0] + + response = json.dumps({ + 'results': [{ + 'text': answer if shared.is_chat() else answer[len(prompt):] + }] + }) + self.wfile.write(response.encode('utf-8')) + elif self.path == '/api/v1/token-count': + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + + tokens = encode(body['prompt'])[0] + response = json.dumps({ + 'results': [{ + 'tokens': len(tokens) + }] + }) + self.wfile.write(response.encode('utf-8')) + else: + self.send_error(404) + + +def _run_server(port: int, share: bool=False): + address = '0.0.0.0' if shared.args.listen else '127.0.0.1' + + 
server = ThreadingHTTPServer((address, port), Handler) + + def on_start(public_url: str): + print(f'Starting non-streaming server at public url {public_url}/api') + + if share: + try: + try_start_cloudflared(port, max_attempts=3, on_start=on_start) + except Exception: + pass + else: + print( + f'Starting API at http://{address}:{port}/api') + + server.serve_forever() + + +def start_server(port: int, share: bool = False): + Thread(target=_run_server, args=[port, share], daemon=True).start() diff --git a/extensions/api/requirements.txt b/extensions/api/requirements.txt index ad788ab8..14e29d35 100644 --- a/extensions/api/requirements.txt +++ b/extensions/api/requirements.txt @@ -1 +1,2 @@ -flask_cloudflared==0.0.12 \ No newline at end of file +flask_cloudflared==0.0.12 +websockets==11.0.2 \ No newline at end of file diff --git a/extensions/api/script.py b/extensions/api/script.py index e4c3a556..efeed71f 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -1,115 +1,10 @@ -import json -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from threading import Thread - +import extensions.api.blocking_api as blocking_api +import extensions.api.streaming_api as streaming_api from modules import shared -from modules.text_generation import encode, generate_reply - -params = { - 'port': 5000, -} - - -class Handler(BaseHTTPRequestHandler): - def do_GET(self): - if self.path == '/api/v1/model': - self.send_response(200) - self.end_headers() - response = json.dumps({ - 'result': shared.model_name - }) - - self.wfile.write(response.encode('utf-8')) - else: - self.send_error(404) - - def do_POST(self): - content_length = int(self.headers['Content-Length']) - body = json.loads(self.rfile.read(content_length).decode('utf-8')) - - if self.path == '/api/v1/generate': - self.send_response(200) - self.send_header('Content-Type', 'application/json') - self.end_headers() - - prompt = body['prompt'] - prompt_lines = [k.strip() for k in 
prompt.split('\n')] - max_context = body.get('max_context_length', 2048) - while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context: - prompt_lines.pop(0) - - prompt = '\n'.join(prompt_lines) - generate_params = { - 'max_new_tokens': int(body.get('max_length', 200)), - 'do_sample': bool(body.get('do_sample', True)), - 'temperature': float(body.get('temperature', 0.5)), - 'top_p': float(body.get('top_p', 1)), - 'typical_p': float(body.get('typical', 1)), - 'repetition_penalty': float(body.get('rep_pen', 1.1)), - 'encoder_repetition_penalty': 1, - 'top_k': int(body.get('top_k', 0)), - 'min_length': int(body.get('min_length', 0)), - 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)), - 'num_beams': int(body.get('num_beams', 1)), - 'penalty_alpha': float(body.get('penalty_alpha', 0)), - 'length_penalty': float(body.get('length_penalty', 1)), - 'early_stopping': bool(body.get('early_stopping', False)), - 'seed': int(body.get('seed', -1)), - 'add_bos_token': int(body.get('add_bos_token', True)), - 'truncation_length': int(body.get('truncation_length', 2048)), - 'ban_eos_token': bool(body.get('ban_eos_token', False)), - 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), - 'custom_stopping_strings': '', # leave this blank - 'stopping_strings': body.get('stopping_strings', []), - } - stopping_strings = generate_params.pop('stopping_strings') - generator = generate_reply(prompt, generate_params, stopping_strings=stopping_strings) - answer = '' - for a in generator: - if isinstance(a, str): - answer = a - else: - answer = a[0] - - response = json.dumps({ - 'results': [{ - 'text': answer if shared.is_chat() else answer[len(prompt):] - }] - }) - self.wfile.write(response.encode('utf-8')) - - elif self.path == '/api/v1/token-count': - # Not compatible with KoboldAI api - self.send_response(200) - self.send_header('Content-Type', 'application/json') - self.end_headers() - - tokens = encode(body['prompt'])[0] - response = 
json.dumps({ - 'results': [{ - 'tokens': len(tokens) - }] - }) - self.wfile.write(response.encode('utf-8')) - - else: - self.send_error(404) - - -def run_server(): - server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port']) - server = ThreadingHTTPServer(server_addr, Handler) - if shared.args.share: - try: - from flask_cloudflared import _run_cloudflared - public_url = _run_cloudflared(params['port'], params['port'] + 1) - print(f'Starting KoboldAI compatible api at {public_url}/api') - except ImportError: - print('You should install flask_cloudflared manually') - else: - print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api') - server.serve_forever() +BLOCKING_PORT = 5000 +STREAMING_PORT = 5005 def setup(): - Thread(target=run_server, daemon=True).start() + blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api) + streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api) diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py new file mode 100644 index 00000000..5ffd925b --- /dev/null +++ b/extensions/api/streaming_api.py @@ -0,0 +1,80 @@ +import json +import asyncio +from websockets.server import serve +from threading import Thread + +from modules import shared +from modules.text_generation import generate_reply + +from extensions.api.util import build_parameters, try_start_cloudflared + +PATH = '/api/v1/stream' + + +async def _handle_connection(websocket, path): + + if path != PATH: + print(f'Streaming api: unknown path: {path}') + return + + async for message in websocket: + message = json.loads(message) + + prompt = message['prompt'] + generate_params = build_parameters(message) + stopping_strings = generate_params.pop('stopping_strings') + + generator = generate_reply( + prompt, generate_params, stopping_strings=stopping_strings) + + # As we stream, only send the new bytes. 
+ skip_index = len(prompt) if not shared.is_chat() else 0 + message_num = 0 + + for a in generator: + to_send = '' + if isinstance(a, str): + to_send = a[skip_index:] + else: + to_send = a[0][skip_index:] + + await websocket.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'text': to_send + })) + + skip_index += len(to_send) + message_num += 1 + + await websocket.send(json.dumps({ + 'event': 'stream_end', + 'message_num': message_num + })) + + +async def _run(host: str, port: int): + async with serve(_handle_connection, host, port): + await asyncio.Future() # run forever + + +def _run_server(port: int, share: bool = False): + address = '0.0.0.0' if shared.args.listen else '127.0.0.1' + + def on_start(public_url: str): + public_url = public_url.replace('https://', 'wss://') + print(f'Starting streaming server at public url {public_url}{PATH}') + + if share: + try: + try_start_cloudflared(port, max_attempts=3, on_start=on_start) + except Exception as e: + print(e) + else: + print(f'Starting streaming server at ws://{address}:{port}{PATH}') + + asyncio.run(_run(host=address, port=port)) + + +def start_server(port: int, share: bool = False): + Thread(target=_run_server, args=[port, share], daemon=True).start() diff --git a/extensions/api/util.py b/extensions/api/util.py new file mode 100644 index 00000000..cb9d9d06 --- /dev/null +++ b/extensions/api/util.py @@ -0,0 +1,69 @@ + +from threading import Thread +import time +from typing import Callable, Optional +from modules.text_generation import encode + + +def build_parameters(body): + prompt = body['prompt'] + + prompt_lines = [k.strip() for k in prompt.split('\n')] + max_context = body.get('max_context_length', 2048) + while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context: + prompt_lines.pop(0) + + prompt = '\n'.join(prompt_lines) + + generate_params = { + 'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))), + 'do_sample': 
bool(body.get('do_sample', True)), + 'temperature': float(body.get('temperature', 0.5)), + 'top_p': float(body.get('top_p', 1)), + 'typical_p': float(body.get('typical_p', body.get('typical', 1))), + 'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))), + 'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)), + 'top_k': int(body.get('top_k', 0)), + 'min_length': int(body.get('min_length', 0)), + 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)), + 'num_beams': int(body.get('num_beams', 1)), + 'penalty_alpha': float(body.get('penalty_alpha', 0)), + 'length_penalty': float(body.get('length_penalty', 1)), + 'early_stopping': bool(body.get('early_stopping', False)), + 'seed': int(body.get('seed', -1)), + 'add_bos_token': int(body.get('add_bos_token', True)), + 'truncation_length': int(body.get('truncation_length', 2048)), + 'ban_eos_token': bool(body.get('ban_eos_token', False)), + 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), + 'custom_stopping_strings': '', # leave this blank + 'stopping_strings': body.get('stopping_strings', []), + } + + return generate_params + + +def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): + Thread(target=_start_cloudflared, args=[ + port, max_attempts, on_start], daemon=True).start() + + +def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): + try: + from flask_cloudflared import _run_cloudflared + except ImportError: + print('You should install flask_cloudflared manually') + raise Exception( + 'flask_cloudflared not installed. 
Make sure you installed the requirements.txt for this extension.') + + for _ in range(max_attempts): + try: + public_url = _run_cloudflared(port, port + 1) + + if on_start: + on_start(public_url) + + return + except Exception: + time.sleep(3) + + raise Exception('Could not start cloudflared.') diff --git a/modules/api.py b/modules/api.py deleted file mode 100644 index 9de8e25d..00000000 --- a/modules/api.py +++ /dev/null @@ -1,52 +0,0 @@ -import json - -import gradio as gr - -from modules import shared -from modules.text_generation import generate_reply - -# set this to True to rediscover the fn_index using the browser DevTools -VISIBLE = False - - -def generate_reply_wrapper(string): - - # Provide defaults so as to not break the API on the client side when new parameters are added - generate_params = { - 'max_new_tokens': 200, - 'do_sample': True, - 'temperature': 0.5, - 'top_p': 1, - 'typical_p': 1, - 'repetition_penalty': 1.1, - 'encoder_repetition_penalty': 1, - 'top_k': 0, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - 'seed': -1, - 'add_bos_token': True, - 'custom_stopping_strings': '', - 'truncation_length': 2048, - 'ban_eos_token': False, - 'skip_special_tokens': True, - 'stopping_strings': [], - } - params = json.loads(string) - generate_params.update(params[1]) - stopping_strings = generate_params.pop('stopping_strings') - for i in generate_reply(params[0], generate_params, stopping_strings=stopping_strings): - yield i - - -def create_apis(): - t1 = gr.Textbox(visible=VISIBLE) - t2 = gr.Textbox(visible=VISIBLE) - dummy = gr.Button(visible=VISIBLE) - - input_params = [t1] - output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']] - dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen') diff --git a/modules/extensions.py b/modules/extensions.py index a6903a9b..24d57f89 100644 --- a/modules/extensions.py +++ 
b/modules/extensions.py @@ -14,7 +14,8 @@ def load_extensions(): global state, setup_called for i, name in enumerate(shared.args.extensions): if name in available_extensions: - print(f'Loading the extension "{name}"... ', end='') + if name != 'api': + print(f'Loading the extension "{name}"... ', end='') try: exec(f"import extensions.{name}.script") extension = getattr(extensions, name).script @@ -22,9 +23,11 @@ def load_extensions(): setup_called.add(extension) extension.setup() state[name] = [True, i] - print('Ok.') + if name != 'api': + print('Ok.') except: - print('Fail.') + if name != 'api': + print('Fail.') traceback.print_exc() diff --git a/modules/shared.py b/modules/shared.py index 1517526a..7540e3fb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -150,6 +150,11 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None) +# API +parser.add_argument('--api', action='store_true', help='Enable the API extension.') +parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.') + + args = parser.parse_args() args_defaults = parser.parse_args([]) @@ -171,6 +176,13 @@ if args.trust_remote_code: if args.share: print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. 
This is potentially dangerous.\n") +# Activating the API extension +if args.api or args.public_api: + if args.extensions is None: + args.extensions = ['api'] + elif 'api' not in args.extensions: + args.extensions.append('api') + def is_chat(): return args.chat diff --git a/server.py b/server.py index 2de817cb..ca44cdb5 100644 --- a/server.py +++ b/server.py @@ -40,7 +40,7 @@ import yaml from PIL import Image import modules.extensions as extensions_module -from modules import api, chat, shared, training, ui +from modules import chat, shared, training, ui from modules.html_generator import chat_html_wrapper from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt, unload_model @@ -714,10 +714,6 @@ def create_interface(): if shared.args.extensions is not None: extensions_module.create_extensions_block() - # Create the invisible elements that define the API - if not shared.is_chat(): - api.create_apis() - # chat mode event handlers if shared.is_chat(): shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]