mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-14 22:39:06 +01:00
ed66ca3cdf
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
143 lines
4.1 KiB
Python
143 lines
4.1 KiB
Python
import asyncio
|
|
import json
|
|
import ssl
|
|
from threading import Thread
|
|
|
|
from websockets.server import serve
|
|
|
|
from extensions.api.util import (
|
|
build_parameters,
|
|
try_start_cloudflared,
|
|
with_api_lock
|
|
)
|
|
from modules import shared
|
|
from modules.chat import generate_chat_reply
|
|
from modules.text_generation import generate_reply
|
|
from modules.logging_colors import logger
|
|
|
|
# URL path of the plain text streaming endpoint; it is appended to the server
# address in the "Starting streaming server at ..." log messages. The chat
# endpoint ('/api/v1/chat-stream') is routed separately in _handle_connection.
PATH = '/api/v1/stream'
|
|
|
|
|
|
@with_api_lock
async def _handle_stream_message(websocket, message):
    """Stream a plain text completion for one request over ``websocket``.

    ``message`` is the raw JSON text of the request. It must contain a
    ``prompt`` key; the request is also handed to ``build_parameters`` to
    obtain the generation parameters (including ``stopping_strings``).

    The client receives ``text_stream`` events, each carrying only the text
    generated since the previous event, followed by one ``stream_end`` event
    whose ``message_num`` equals the number of ``text_stream`` events sent.
    """
    request = json.loads(message)

    prompt = request['prompt']
    generate_params = build_parameters(request)
    stopping_strings = generate_params.pop('stopping_strings')
    generate_params['stream'] = True

    generator = generate_reply(
        prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

    # As we stream, only send the characters not yet sent to the client.
    skip_index = 0
    message_num = 0
    reply = ''

    async def send_text(text):
        """Send one ``text_stream`` event and advance the event counter."""
        nonlocal message_num
        await websocket.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'text': text
        }))
        # Yield to the event loop between chunks.
        await asyncio.sleep(0)
        message_num += 1

    for reply in generator:
        to_send = reply[skip_index:]
        # Skip when nothing new was generated, or when the new text contains
        # U+FFFD (a partially decoded unicode character): wait for more output.
        if not to_send or '\ufffd' in to_send:
            continue

        await send_text(to_send)
        skip_index += len(to_send)

    # Flush anything held back above. If the final reply still contains
    # U+FFFD it will never be completed, and without this the text from that
    # point onward would be silently dropped from the stream.
    remainder = reply[skip_index:]
    if remainder:
        await send_text(remainder)

    await websocket.send(json.dumps({
        'event': 'stream_end',
        'message_num': message_num
    }))
|
|
|
|
|
|
@with_api_lock
async def _handle_chat_stream_message(websocket, message):
    """Stream a chat reply for one request over ``websocket``.

    ``message`` is the raw JSON request text. It must contain ``user_input``
    and may set ``regenerate`` and ``_continue`` (both default to False).
    Every value yielded by ``generate_chat_reply`` is sent whole as the
    ``history`` field of a ``text_stream`` event. A closing ``stream_end``
    event reports how many ``text_stream`` events were sent.
    """
    request = json.loads(message)
    user_input = request['user_input']

    params = build_parameters(request, chat=True)
    params['stream'] = True

    history_stream = generate_chat_reply(
        user_input,
        params,
        regenerate=request.get('regenerate', False),
        _continue=request.get('_continue', False),
        loading_message=False,
    )

    # ``sent`` ends up as the number of events sent (0 if nothing was yielded).
    sent = 0
    for sent, history in enumerate(history_stream, start=1):
        event = {'event': 'text_stream', 'message_num': sent - 1, 'history': history}
        await websocket.send(json.dumps(event))
        await asyncio.sleep(0)  # let other coroutines run between chunks

    await websocket.send(json.dumps({'event': 'stream_end', 'message_num': sent}))
|
|
|
|
|
|
async def _handle_connection(websocket, path):
    """Serve one websocket connection, routing it by request ``path``.

    ``PATH`` ('/api/v1/stream') streams plain text completions and
    '/api/v1/chat-stream' streams chat replies. Each message received on the
    connection is passed, in order, to the matching handler. For any other
    path a warning is logged and the handler returns immediately.
    """
    handlers = {
        PATH: _handle_stream_message,
        '/api/v1/chat-stream': _handle_chat_stream_message,
    }

    handler = handlers.get(path)
    if handler is None:
        logger.warning(f'Streaming api: unknown path: {path}')
        return

    async for message in websocket:
        await handler(websocket, message)
|
|
|
|
|
|
async def _run(host: str, port: int):
    """Serve the streaming API on ``host``:``port`` and never return.

    TLS is enabled only when both ``shared.args.ssl_certfile`` and
    ``shared.args.ssl_keyfile`` are set; otherwise a plain (non-TLS)
    server is started.
    """
    certfile = shared.args.ssl_certfile
    keyfile = shared.args.ssl_keyfile

    context = None
    if keyfile and certfile:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile, keyfile)

    # ping_interval=None disables the websockets library's automatic keepalive pings.
    async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
        await asyncio.Future()  # never completes, so the server runs forever
|
|
|
|
|
|
def _run_server(port: int, share: bool = False, tunnel_id=None):
    """Log the server address, then run the streaming server in this thread.

    Blocks for as long as ``_run`` keeps serving.

    Args:
        port: Port to listen on. The host is 0.0.0.0 when
            ``shared.args.listen`` is set, otherwise 127.0.0.1.
        share: If True, try to expose the server through
            ``try_start_cloudflared``. A failure there is logged and the
            server still starts locally.
        tunnel_id: Passed unchanged to ``try_start_cloudflared``; defaults to
            None. The old default was the ``str`` type object itself, which
            is not a usable value.
    """
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    ssl_enabled = bool(shared.args.ssl_keyfile and shared.args.ssl_certfile)

    def on_start(public_url: str):
        # The tunnel reports an https:// URL; websocket clients need wss://.
        public_url = public_url.replace('https://', 'wss://')
        logger.info(f'Starting streaming server at public url {public_url}{PATH}')

    if share:
        try:
            try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception as e:
            # Best effort: keep serving locally even if the tunnel cannot start.
            logger.error(f'Could not start cloudflared tunnel: {e}')
    else:
        scheme = 'wss' if ssl_enabled else 'ws'
        logger.info(f'Starting streaming server at {scheme}://{address}:{port}{PATH}')

    asyncio.run(_run(host=address, port=port))
|
|
|
|
|
|
def start_server(port: int, share: bool = False, tunnel_id=None):
    """Launch the streaming API server in a background daemon thread.

    Returns immediately; ``_run_server(port, share, tunnel_id)`` runs in the
    new thread. Because the thread is a daemon, it does not keep the process
    alive by itself.

    Args:
        port: Port for the streaming server.
        share: Forwarded to ``_run_server``.
        tunnel_id: Forwarded to ``_run_server``; defaults to None. The old
            default was the ``str`` type object itself, which is not a
            usable value.
    """
    Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()
|