From c5af549d4b929fec0148704d29af1984ed5d9247 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 20 May 2023 18:42:17 -0300 Subject: [PATCH] Add chat API (#2233) --- api-example-chat-stream.py | 88 +++++++++++++++++++++++++++++++++ api-example-chat.py | 68 +++++++++++++++++++++++++ extensions/api/blocking_api.py | 32 ++++++++++++ extensions/api/script.py | 1 + extensions/api/streaming_api.py | 83 +++++++++++++++++++++---------- extensions/api/util.py | 36 +++++++++----- extensions/superbooga/script.py | 8 ++- modules/chat.py | 68 ++++++++++++++----------- 8 files changed, 317 insertions(+), 67 deletions(-) create mode 100644 api-example-chat-stream.py create mode 100644 api-example-chat.py diff --git a/api-example-chat-stream.py b/api-example-chat-stream.py new file mode 100644 index 00000000..fb048c60 --- /dev/null +++ b/api-example-chat-stream.py @@ -0,0 +1,88 @@ +import asyncio +import json +import sys + +try: + import websockets +except ImportError: + print("Websockets package not found. Make sure it's installed.") + +# For local streaming, the websockets are hosted without ssl - ws:// +HOST = 'localhost:5005' +URI = f'ws://{HOST}/api/v1/chat-stream' + +# For reverse-proxied streaming, the remote will likely host with ssl - wss:// +# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' + + +async def run(user_input, history): + # Note: the selected defaults change from time to time. + request = { + 'user_input': user_input, + 'history': history, + 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' + 'character': 'Example', + 'instruction_template': 'Vicuna-v1.1', + + 'regenerate': False, + '_continue': False, + 'stop_at_newline': False, + 'chat_prompt_size': 2048, + 'chat_generation_attempts': 1, + 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', + + 'max_new_tokens': 250, + 'do_sample': True, + 'temperature': 0.7, + 'top_p': 0.1, + 'typical_p': 1, + 'repetition_penalty': 1.18, + 'top_k': 40, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'length_penalty': 1, + 'early_stopping': False, + 'seed': -1, + 'add_bos_token': True, + 'truncation_length': 2048, + 'ban_eos_token': False, + 'skip_special_tokens': True, + 'stopping_strings': [] + } + + async with websockets.connect(URI, ping_interval=None) as websocket: + await websocket.send(json.dumps(request)) + + while True: + incoming_data = await websocket.recv() + incoming_data = json.loads(incoming_data) + + match incoming_data['event']: + case 'text_stream': + yield incoming_data['history'] + case 'stream_end': + return + + +async def print_response_stream(user_input, history): + cur_len = 0 + async for new_history in run(user_input, history): + cur_message = new_history['visible'][-1][1][cur_len:] + cur_len += len(cur_message) + print(cur_message, end='') + sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. + + +if __name__ == '__main__': + user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard." + + # Basic example + history = {'internal': [], 'visible': []} + + # "Continue" example. Make sure to set '_continue' to True above + # arr = [user_input, 'Surely, here is'] + # history = {'internal': [arr], 'visible': [arr]} + + asyncio.run(print_response_stream(user_input, history)) diff --git a/api-example-chat.py b/api-example-chat.py new file mode 100644 index 00000000..7da92e6c --- /dev/null +++ b/api-example-chat.py @@ -0,0 +1,68 @@ +import json + +import requests + +# For local streaming, the websockets are hosted without ssl - http:// +HOST = 'localhost:5000' +URI = f'http://{HOST}/api/v1/chat' + +# For reverse-proxied streaming, the remote will likely host with ssl - https:// +# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate' + + +def run(user_input, history): + request = { + 'user_input': user_input, + 'history': history, + 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' + 'character': 'Example', + 'instruction_template': 'Vicuna-v1.1', + + 'regenerate': False, + '_continue': False, + 'stop_at_newline': False, + 'chat_prompt_size': 2048, + 'chat_generation_attempts': 1, + 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', + + 'max_new_tokens': 250, + 'do_sample': True, + 'temperature': 0.7, + 'top_p': 0.1, + 'typical_p': 1, + 'repetition_penalty': 1.18, + 'top_k': 40, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'length_penalty': 1, + 'early_stopping': False, + 'seed': -1, + 'add_bos_token': True, + 'truncation_length': 2048, + 'ban_eos_token': False, + 'skip_special_tokens': True, + 'stopping_strings': [] + } + + response = requests.post(URI, json=request) + + if response.status_code == 200: + result = response.json()['results'][0]['history'] + print(json.dumps(result, indent=4)) + print() + print(result['visible'][-1][1]) + + +if __name__ == '__main__': + user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard." + + # Basic example + history = {'internal': [], 'visible': []} + + # "Continue" example. Make sure to set '_continue' to True above + # arr = [user_input, 'Surely, here is'] + # history = {'internal': [arr], 'visible': [arr]} + + run(user_input, history) diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py index 134e99d4..8c2326f4 100644 --- a/extensions/api/blocking_api.py +++ b/extensions/api/blocking_api.py @@ -4,6 +4,7 @@ from threading import Thread from extensions.api.util import build_parameters, try_start_cloudflared from modules import shared +from modules.chat import generate_chat_reply from modules.text_generation import encode, generate_reply @@ -46,7 +47,37 @@ class Handler(BaseHTTPRequestHandler): 'text': answer }] }) + self.wfile.write(response.encode('utf-8')) + + elif self.path == '/api/v1/chat': + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + + user_input = body['user_input'] + history = body['history'] + regenerate = body.get('regenerate', False) + _continue = body.get('_continue', False) + + generate_params = build_parameters(body, chat=True) + generate_params['stream'] = False + + generator = generate_chat_reply( + user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False) + + answer = history + for a in generator: + answer = a + + response = json.dumps({ + 'results': [{ + 'history': answer + }] + }) + + self.wfile.write(response.encode('utf-8')) + elif self.path == '/api/v1/token-count': self.send_response(200) self.send_header('Content-Type', 'application/json') @@ -58,6 +89,7 @@ class Handler(BaseHTTPRequestHandler): 'tokens': len(tokens) }] }) + self.wfile.write(response.encode('utf-8')) else: self.send_error(404) diff --git a/extensions/api/script.py b/extensions/api/script.py index 5544de72..5d1b1a68 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -2,6 +2,7 @@ import extensions.api.blocking_api as blocking_api import extensions.api.streaming_api as streaming_api from modules import shared + def setup(): blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api) streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api) diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py index e50dfa22..717a8088 100644 --- a/extensions/api/streaming_api.py +++ b/extensions/api/streaming_api.py @@ -6,6 +6,7 @@ from websockets.server import serve from extensions.api.util import build_parameters, try_start_cloudflared from modules import shared +from modules.chat import generate_chat_reply from modules.text_generation import generate_reply PATH = '/api/v1/stream' @@ -13,42 +14,72 @@ PATH = '/api/v1/stream' async def _handle_connection(websocket, path): - if path != PATH: - print(f'Streaming api: unknown path: {path}') - return + if path == '/api/v1/stream': + async for message in websocket: + message = json.loads(message) - async for message in websocket: - message = json.loads(message) + prompt = message['prompt'] + generate_params = build_parameters(message) + stopping_strings = generate_params.pop('stopping_strings') + generate_params['stream'] = True - prompt = message['prompt'] - generate_params = build_parameters(message) - stopping_strings = generate_params.pop('stopping_strings') - generate_params['stream'] = True + generator = generate_reply( + prompt, generate_params, stopping_strings=stopping_strings, is_chat=False) - generator = generate_reply( - prompt, generate_params, stopping_strings=stopping_strings, is_chat=False) + # As we stream, only send the new bytes. + skip_index = 0 + message_num = 0 - # As we stream, only send the new bytes. - skip_index = 0 - message_num = 0 + for a in generator: + to_send = a[skip_index:] + await websocket.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'text': to_send + })) + + await asyncio.sleep(0) + skip_index += len(to_send) + message_num += 1 - for a in generator: - to_send = a[skip_index:] await websocket.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': to_send + 'event': 'stream_end', + 'message_num': message_num })) - await asyncio.sleep(0) + elif path == '/api/v1/chat-stream': + async for message in websocket: + body = json.loads(message) - skip_index += len(to_send) - message_num += 1 + user_input = body['user_input'] + history = body['history'] + generate_params = build_parameters(body, chat=True) + generate_params['stream'] = True + regenerate = body.get('regenerate', False) + _continue = body.get('_continue', False) - await websocket.send(json.dumps({ - 'event': 'stream_end', - 'message_num': message_num - })) + generator = generate_chat_reply( + user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False) + + message_num = 0 + for a in generator: + await websocket.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'history': a + })) + + await asyncio.sleep(0) + message_num += 1 + + await websocket.send(json.dumps({ + 'event': 'stream_end', + 'message_num': message_num + })) + + else: + print(f'Streaming api: unknown path: {path}') + return async def _run(host: str, port: int): diff --git a/extensions/api/util.py b/extensions/api/util.py index e637ac0e..369381e3 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -3,18 +3,11 @@ import traceback from threading import Thread from typing import Callable, Optional -from modules.text_generation import get_encoded_length +from modules import shared +from modules.chat import load_character -def build_parameters(body): - prompt = body['prompt'] - - prompt_lines = [k.strip() for k in prompt.split('\n')] - max_context = body.get('max_context_length', 2048) - while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context: - prompt_lines.pop(0) - - prompt = '\n'.join(prompt_lines) +def build_parameters(body, chat=False): generate_params = { 'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))), @@ -33,13 +26,34 @@ def build_parameters(body): 'early_stopping': bool(body.get('early_stopping', False)), 'seed': int(body.get('seed', -1)), 'add_bos_token': bool(body.get('add_bos_token', True)), - 'truncation_length': int(body.get('truncation_length', 2048)), + 'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))), 'ban_eos_token': bool(body.get('ban_eos_token', False)), 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), 'custom_stopping_strings': '', # leave this blank 'stopping_strings': body.get('stopping_strings', []), } + if chat: + character = body.get('character') + instruction_template = body.get('instruction_template') + name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False) + name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True) + generate_params.update({ + 'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])), + 'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])), + 'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])), + 'mode': str(body.get('mode', 'chat')), + 'name1': name1, + 'name2': name2, + 'context': context, + 'greeting': greeting, + 'name1_instruct': name1_instruct, + 'name2_instruct': name2_instruct, + 'context_instruct': context_instruct, + 'turn_template': turn_template, + 'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])), + }) + return generate_params diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index a1d66add..c9f1a22d 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -4,12 +4,12 @@ import textwrap import gradio as gr from bs4 import BeautifulSoup + from modules import chat, shared from .chromadb import add_chunks_to_collector, make_collector from .download_urls import download_urls - params = { 'chunk_count': 5, 'chunk_length': 700, @@ -40,6 +40,7 @@ def feed_data_into_collector(corpus, chunk_len, chunk_sep): data_chunks = [x for y in data_chunks for x in y] else: data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)] + cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n" yield cumulative add_chunks_to_collector(data_chunks, collector) @@ -124,7 +125,10 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): logging.warning(f'Adding the following new context:\n{additional_context}') state['context'] = state['context'].strip() + '\n' + additional_context - state['history'] = [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids] + kwargs['history'] = { + 'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids], + 'visible': '' + } except RuntimeError: logging.error("Couldn't query the database, moving on...") diff --git a/modules/chat.py b/modules/chat.py index 3055a97a..7a21d7be 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -50,7 +50,7 @@ def generate_chat_prompt(user_input, state, **kwargs): impersonate = kwargs.get('impersonate', False) _continue = kwargs.get('_continue', False) also_return_rows = kwargs.get('also_return_rows', False) - history = state.get('history', shared.history['internal']) + history = kwargs.get('history', shared.history)['internal'] is_instruct = state['mode'] == 'instruct' # Finding the maximum prompt size @@ -59,11 +59,11 @@ def generate_chat_prompt(user_input, state, **kwargs): chat_prompt_size -= shared.soft_prompt_tensor.shape[1] max_length = min(get_max_prompt_length(state), chat_prompt_size) - all_substrings = { 'chat': get_turn_substrings(state, instruct=False), 'instruct': get_turn_substrings(state, instruct=True) } + substrings = all_substrings['instruct' if is_instruct else 'chat'] # Creating the template for "chat-instruct" mode @@ -179,10 +179,11 @@ def extract_message_from_reply(reply, state): return reply, next_character_found -def chatbot_wrapper(text, state, regenerate=False, _continue=False): +def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True): + output = copy.deepcopy(history) if shared.model_name == 'None' or shared.model is None: logging.error("No model is loaded! Select one in the Model tab.") - yield shared.history['visible'] + yield output return # Defining some variables @@ -200,20 +201,27 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): text = apply_extensions('input', text) # *Is typing...* - yield shared.history['visible'] + [[visible_text, shared.processing_message]] + if loading_message: + yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} else: - text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0] + text, visible_text = output['internal'][-1][0], output['visible'][-1][0] if regenerate: - shared.history['visible'].pop() - shared.history['internal'].pop() + output['visible'].pop() + output['internal'].pop() # *Is typing...* - yield shared.history['visible'] + [[visible_text, shared.processing_message]] + if loading_message: + yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} elif _continue: - last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] - yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']] + last_reply = [output['internal'][-1][1], output['visible'][-1][1]] + if loading_message: + yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']} # Generating the prompt - kwargs = {'_continue': _continue} + kwargs = { + '_continue': _continue, + 'history': output, + } + prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs) if prompt is None: prompt = generate_chat_prompt(text, state, **kwargs) @@ -232,22 +240,23 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): # We need this global variable to handle the Stop event, # otherwise gradio gets confused if shared.stop_everything: - return shared.history['visible'] + yield output + return if just_started: just_started = False if not _continue: - shared.history['internal'].append(['', '']) - shared.history['visible'].append(['', '']) + output['internal'].append(['', '']) + output['visible'].append(['', '']) if _continue: - shared.history['internal'][-1] = [text, last_reply[0] + reply] - shared.history['visible'][-1] = [visible_text, last_reply[1] + visible_reply] - yield shared.history['visible'] + output['internal'][-1] = [text, last_reply[0] + reply] + output['visible'][-1] = [visible_text, last_reply[1] + visible_reply] + yield output elif not (j == 0 and visible_reply.strip() == ''): - shared.history['internal'][-1] = [text, reply] - shared.history['visible'][-1] = [visible_text, visible_reply] - yield shared.history['visible'] + output['internal'][-1] = [text, reply.lstrip(' ')] + output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')] + yield output if next_character_found: break @@ -257,7 +266,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): else: cumulative_reply = reply - yield shared.history['visible'] + yield output def impersonate_wrapper(text, state): @@ -291,21 +300,24 @@ def impersonate_wrapper(text, state): yield cumulative_reply -def generate_chat_reply(text, state, regenerate=False, _continue=False): +def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True): if regenerate or _continue: text = '' - if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0: - yield shared.history['visible'] + if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0: + yield history return - for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue): + for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message): yield history # Same as above but returns HTML def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): - for history in generate_chat_reply(text, state, regenerate, _continue): - yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style']) + for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)): + if i != 0: + shared.history = copy.deepcopy(history) + + yield chat_html_wrapper(history['visible'], state['name1'], state['name2'], state['mode'], state['chat_style']) def remove_last_message():