Add chat API (#2233)

oobabooga 2023-05-20 18:42:17 -03:00 committed by GitHub
parent 2aa01e2303
commit c5af549d4b
8 changed files with 317 additions and 67 deletions

api-example-chat-stream.py Normal file

@@ -0,0 +1,88 @@
import asyncio
import json
import sys

try:
    import websockets
except ImportError:
    print("Websockets package not found. Make sure it's installed.")

# For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5005'
URI = f'ws://{HOST}/api/v1/chat-stream'

# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/chat-stream'


async def run(user_input, history):
    # Note: the selected defaults change from time to time.
    request = {
        'user_input': user_input,
        'history': history,
        'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
        'character': 'Example',
        'instruction_template': 'Vicuna-v1.1',

        'regenerate': False,
        '_continue': False,
        'stop_at_newline': False,
        'chat_prompt_size': 2048,
        'chat_generation_attempts': 1,
        'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',

        'max_new_tokens': 250,
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'repetition_penalty': 1.18,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    async with websockets.connect(URI, ping_interval=None) as websocket:
        await websocket.send(json.dumps(request))

        while True:
            incoming_data = await websocket.recv()
            incoming_data = json.loads(incoming_data)

            match incoming_data['event']:
                case 'text_stream':
                    yield incoming_data['history']
                case 'stream_end':
                    return


async def print_response_stream(user_input, history):
    cur_len = 0
    async for new_history in run(user_input, history):
        cur_message = new_history['visible'][-1][1][cur_len:]
        cur_len += len(cur_message)
        print(cur_message, end='')
        sys.stdout.flush()  # If we don't flush, we won't see tokens in realtime.


if __name__ == '__main__':
    user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."

    # Basic example
    history = {'internal': [], 'visible': []}

    # "Continue" example. Make sure to set '_continue' to True above
    # arr = [user_input, 'Surely, here is']
    # history = {'internal': [arr], 'visible': [arr]}

    asyncio.run(print_response_stream(user_input, history))
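A note on the payload shape: history is a dict with 'internal' and 'visible' keys, each holding a list of [user_text, bot_text] pairs ('internal' is the raw exchange, 'visible' is what the UI renders). A minimal sketch of a pre-seeded one-turn history, with invented messages:

# Hypothetical pre-seeded history: one completed exchange. Sending this
# instead of the empty lists above makes the next reply context-aware.
history = {
    'internal': [['Hi!', 'Hello! How can I help you today?']],
    'visible': [['Hi!', 'Hello! How can I help you today?']],
}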

api-example-chat.py Normal file

@@ -0,0 +1,68 @@
import json

import requests

# For local blocking requests, the API is hosted without ssl - http://
HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/chat'

# For reverse-proxied requests, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/chat'


def run(user_input, history):
    request = {
        'user_input': user_input,
        'history': history,
        'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
        'character': 'Example',
        'instruction_template': 'Vicuna-v1.1',

        'regenerate': False,
        '_continue': False,
        'stop_at_newline': False,
        'chat_prompt_size': 2048,
        'chat_generation_attempts': 1,
        'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',

        'max_new_tokens': 250,
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'repetition_penalty': 1.18,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    response = requests.post(URI, json=request)

    if response.status_code == 200:
        result = response.json()['results'][0]['history']
        print(json.dumps(result, indent=4))
        print()
        print(result['visible'][-1][1])


if __name__ == '__main__':
    user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."

    # Basic example
    history = {'internal': [], 'visible': []}

    # "Continue" example. Make sure to set '_continue' to True above
    # arr = [user_input, 'Surely, here is']
    # history = {'internal': [arr], 'visible': [arr]}

    run(user_input, history)
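Because the endpoint returns the updated history, a multi-turn conversation amounts to feeding each response back into the next request. A minimal sketch under the same assumptions as the script above (server running locally, 'Example' character and 'Vicuna-v1.1' template available); the questions are made up, and build_parameters() on the server fills in defaults for every generation parameter omitted from the body:

import requests

HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/chat'


def chat(user_input, history):
    # Only the essentials are sent; the server falls back to its own
    # defaults for any generation parameter left out of the request.
    request = {
        'user_input': user_input,
        'history': history,
        'mode': 'chat',
        'character': 'Example',
        'instruction_template': 'Vicuna-v1.1',
    }
    response = requests.post(URI, json=request)
    response.raise_for_status()
    return response.json()['results'][0]['history']


history = {'internal': [], 'visible': []}
history = chat('What tools do I need to plant a tree?', history)
history = chat('And how deep should the hole be?', history)  # sees turn 1
print(history['visible'][-1][1])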

extensions/api/blocking_api.py

@@ -4,6 +4,7 @@ from threading import Thread
 from extensions.api.util import build_parameters, try_start_cloudflared
 from modules import shared
+from modules.chat import generate_chat_reply
 from modules.text_generation import encode, generate_reply

@@ -46,7 +47,37 @@ class Handler(BaseHTTPRequestHandler):
                     'text': answer
                 }]
             })

             self.wfile.write(response.encode('utf-8'))
+        elif self.path == '/api/v1/chat':
+            self.send_response(200)
+            self.send_header('Content-Type', 'application/json')
+            self.end_headers()
+
+            user_input = body['user_input']
+            history = body['history']
+            regenerate = body.get('regenerate', False)
+            _continue = body.get('_continue', False)
+
+            generate_params = build_parameters(body, chat=True)
+            generate_params['stream'] = False
+
+            generator = generate_chat_reply(
+                user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
+
+            answer = history
+            for a in generator:
+                answer = a
+
+            response = json.dumps({
+                'results': [{
+                    'history': answer
+                }]
+            })
+
+            self.wfile.write(response.encode('utf-8'))
         elif self.path == '/api/v1/token-count':
             self.send_response(200)
             self.send_header('Content-Type', 'application/json')

@@ -58,6 +89,7 @@ class Handler(BaseHTTPRequestHandler):
                     'tokens': len(tokens)
                 }]
             })
+
             self.wfile.write(response.encode('utf-8'))
         else:
             self.send_error(404)
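The new /api/v1/chat branch mirrors the existing blocking endpoints: it drains the generator and returns only the final state, as a results list whose single element carries the updated history. A sketch of the response shape, with an invented reply:

# Hypothetical JSON body returned by POST /api/v1/chat (reply text made up);
# 'history' is the last value yielded by generate_chat_reply().
example_response = {
    'results': [{
        'history': {
            'internal': [['Please give me a step-by-step guide...', '1. Pick a sunny spot...']],
            'visible': [['Please give me a step-by-step guide...', '1. Pick a sunny spot...']],
        }
    }]
}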

extensions/api/script.py

@@ -2,6 +2,7 @@ import extensions.api.blocking_api as blocking_api
 import extensions.api.streaming_api as streaming_api
 from modules import shared


 def setup():
     blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api)
     streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api)

extensions/api/streaming_api.py

@@ -6,6 +6,7 @@ from websockets.server import serve
 from extensions.api.util import build_parameters, try_start_cloudflared
 from modules import shared
+from modules.chat import generate_chat_reply
 from modules.text_generation import generate_reply

 PATH = '/api/v1/stream'

@@ -13,42 +14,72 @@ PATH = '/api/v1/stream'
 async def _handle_connection(websocket, path):
-    if path != PATH:
-        print(f'Streaming api: unknown path: {path}')
-        return
-
-    async for message in websocket:
-        message = json.loads(message)
-
-        prompt = message['prompt']
-        generate_params = build_parameters(message)
-        stopping_strings = generate_params.pop('stopping_strings')
-        generate_params['stream'] = True
-
-        generator = generate_reply(
-            prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
-
-        # As we stream, only send the new bytes.
-        skip_index = 0
-        message_num = 0
-
-        for a in generator:
-            to_send = a[skip_index:]
-            await websocket.send(json.dumps({
-                'event': 'text_stream',
-                'message_num': message_num,
-                'text': to_send
-            }))
-
-            await asyncio.sleep(0)
-
-            skip_index += len(to_send)
-            message_num += 1
-
-        await websocket.send(json.dumps({
-            'event': 'stream_end',
-            'message_num': message_num
-        }))
+    if path == '/api/v1/stream':
+        async for message in websocket:
+            message = json.loads(message)
+            prompt = message['prompt']
+            generate_params = build_parameters(message)
+            stopping_strings = generate_params.pop('stopping_strings')
+            generate_params['stream'] = True
+
+            generator = generate_reply(
+                prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
+
+            # As we stream, only send the new bytes.
+            skip_index = 0
+            message_num = 0
+            for a in generator:
+                to_send = a[skip_index:]
+                await websocket.send(json.dumps({
+                    'event': 'text_stream',
+                    'message_num': message_num,
+                    'text': to_send
+                }))
+
+                await asyncio.sleep(0)
+                skip_index += len(to_send)
+                message_num += 1
+
+            await websocket.send(json.dumps({
+                'event': 'stream_end',
+                'message_num': message_num
+            }))
+
+    elif path == '/api/v1/chat-stream':
+        async for message in websocket:
+            body = json.loads(message)
+
+            user_input = body['user_input']
+            history = body['history']
+            generate_params = build_parameters(body, chat=True)
+            generate_params['stream'] = True
+            regenerate = body.get('regenerate', False)
+            _continue = body.get('_continue', False)
+
+            generator = generate_chat_reply(
+                user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
+
+            message_num = 0
+            for a in generator:
+                await websocket.send(json.dumps({
+                    'event': 'text_stream',
+                    'message_num': message_num,
+                    'history': a
+                }))
+
+                await asyncio.sleep(0)
+                message_num += 1
+
+            await websocket.send(json.dumps({
+                'event': 'stream_end',
+                'message_num': message_num
+            }))
+
+    else:
+        print(f'Streaming api: unknown path: {path}')
+        return


 async def _run(host: str, port: int):
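Note the protocol difference between the two paths: /api/v1/stream sends incremental text deltas, while /api/v1/chat-stream sends the entire updated history on every text_stream event (which is why api-example-chat-stream.py computes the delta client-side). A sketch of the two frame shapes a chat-stream client receives, with invented content:

# Hypothetical frames from /api/v1/chat-stream (content made up).
text_stream_frame = {
    'event': 'text_stream',
    'message_num': 0,
    'history': {
        'internal': [['Please give me a step-by-step guide...', '1. Pick']],
        'visible': [['Please give me a step-by-step guide...', '1. Pick']],
    },
}
stream_end_frame = {'event': 'stream_end', 'message_num': 1}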

extensions/api/util.py

@@ -3,18 +3,11 @@ import traceback
 from threading import Thread
 from typing import Callable, Optional

-from modules.text_generation import get_encoded_length
+from modules import shared
+from modules.chat import load_character


-def build_parameters(body):
-    prompt = body['prompt']
-    prompt_lines = [k.strip() for k in prompt.split('\n')]
-    max_context = body.get('max_context_length', 2048)
-    while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
-        prompt_lines.pop(0)
-
-    prompt = '\n'.join(prompt_lines)
-
+def build_parameters(body, chat=False):
     generate_params = {
         'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),

@@ -33,13 +26,34 @@ def build_parameters(body):
         'early_stopping': bool(body.get('early_stopping', False)),
         'seed': int(body.get('seed', -1)),
         'add_bos_token': bool(body.get('add_bos_token', True)),
-        'truncation_length': int(body.get('truncation_length', 2048)),
+        'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
         'ban_eos_token': bool(body.get('ban_eos_token', False)),
         'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
         'custom_stopping_strings': '',  # leave this blank
         'stopping_strings': body.get('stopping_strings', []),
     }

+    if chat:
+        character = body.get('character')
+        instruction_template = body.get('instruction_template')
+        name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
+        name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True)
+        generate_params.update({
+            'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
+            'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
+            'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
+            'mode': str(body.get('mode', 'chat')),
+            'name1': name1,
+            'name2': name2,
+            'context': context,
+            'greeting': greeting,
+            'name1_instruct': name1_instruct,
+            'name2_instruct': name2_instruct,
+            'context_instruct': context_instruct,
+            'turn_template': turn_template,
+            'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
+        })
+
     return generate_params
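With chat=True, build_parameters() now resolves the character and instruction template server-side, so API clients only pass their names. A sketch of the two call paths (assumes it runs inside the webui, where these modules and the 'Example' character are available):

from extensions.api.util import build_parameters

# Plain completion request: generation settings only.
params = build_parameters({'prompt': 'Hello', 'max_new_tokens': 50})

# Chat request: also loads the character and instruction template,
# filling in name1/name2, context, greeting and turn_template fields.
chat_params = build_parameters({
    'user_input': 'Hello',
    'history': {'internal': [], 'visible': []},
    'mode': 'chat',
    'character': 'Example',
    'instruction_template': 'Vicuna-v1.1',
}, chat=True)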

extensions/superbooga/script.py

@@ -4,12 +4,12 @@ import textwrap

 import gradio as gr
 from bs4 import BeautifulSoup

 from modules import chat, shared

 from .chromadb import add_chunks_to_collector, make_collector
 from .download_urls import download_urls

 params = {
     'chunk_count': 5,
     'chunk_length': 700,

@@ -40,6 +40,7 @@ def feed_data_into_collector(corpus, chunk_len, chunk_sep):
         data_chunks = [x for y in data_chunks for x in y]
     else:
         data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]

     cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
     yield cumulative
     add_chunks_to_collector(data_chunks, collector)

@@ -124,7 +125,10 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
             logging.warning(f'Adding the following new context:\n{additional_context}')
             state['context'] = state['context'].strip() + '\n' + additional_context
-            state['history'] = [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids]
+            kwargs['history'] = {
+                'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
+                'visible': ''
+            }
         except RuntimeError:
             logging.error("Couldn't query the database, moving on...")
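This hunk illustrates the extension-facing change: custom prompt generators now override the history through kwargs rather than through state. A sketch of a minimal hook under the new convention (assumes it runs as a text-generation-webui extension; the filtering rule is invented):

# Hypothetical extension hook: inject a trimmed history via kwargs,
# then defer to the default prompt builder.
from modules import chat, shared


def custom_generate_chat_prompt(user_input, state, **kwargs):
    kwargs['history'] = {
        'internal': shared.history['internal'][-4:],  # keep only the last 4 turns
        'visible': '',
    }
    return chat.generate_chat_prompt(user_input, state, **kwargs)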

modules/chat.py

@@ -50,7 +50,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
     impersonate = kwargs.get('impersonate', False)
     _continue = kwargs.get('_continue', False)
     also_return_rows = kwargs.get('also_return_rows', False)
-    history = state.get('history', shared.history['internal'])
+    history = kwargs.get('history', shared.history)['internal']
     is_instruct = state['mode'] == 'instruct'

     # Finding the maximum prompt size

@@ -59,11 +59,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
         chat_prompt_size -= shared.soft_prompt_tensor.shape[1]

     max_length = min(get_max_prompt_length(state), chat_prompt_size)
     all_substrings = {
         'chat': get_turn_substrings(state, instruct=False),
         'instruct': get_turn_substrings(state, instruct=True)
     }
     substrings = all_substrings['instruct' if is_instruct else 'chat']

     # Creating the template for "chat-instruct" mode

@@ -179,10 +179,11 @@ def extract_message_from_reply(reply, state):
     return reply, next_character_found


-def chatbot_wrapper(text, state, regenerate=False, _continue=False):
+def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
+    output = copy.deepcopy(history)
     if shared.model_name == 'None' or shared.model is None:
         logging.error("No model is loaded! Select one in the Model tab.")
-        yield shared.history['visible']
+        yield output
         return

     # Defining some variables

@@ -200,20 +201,27 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
             text = apply_extensions('input', text)

         # *Is typing...*
-        yield shared.history['visible'] + [[visible_text, shared.processing_message]]
+        if loading_message:
+            yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
     else:
-        text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0]
+        text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
         if regenerate:
-            shared.history['visible'].pop()
-            shared.history['internal'].pop()
+            output['visible'].pop()
+            output['internal'].pop()

             # *Is typing...*
-            yield shared.history['visible'] + [[visible_text, shared.processing_message]]
+            if loading_message:
+                yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
         elif _continue:
-            last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]]
-            yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']]
+            last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
+            if loading_message:
+                yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']}

     # Generating the prompt
-    kwargs = {'_continue': _continue}
+    kwargs = {
+        '_continue': _continue,
+        'history': output,
+    }
     prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
     if prompt is None:
         prompt = generate_chat_prompt(text, state, **kwargs)

@@ -232,22 +240,23 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
             # We need this global variable to handle the Stop event,
             # otherwise gradio gets confused
             if shared.stop_everything:
-                return shared.history['visible']
+                yield output
+                return

             if just_started:
                 just_started = False
                 if not _continue:
-                    shared.history['internal'].append(['', ''])
-                    shared.history['visible'].append(['', ''])
+                    output['internal'].append(['', ''])
+                    output['visible'].append(['', ''])

             if _continue:
-                shared.history['internal'][-1] = [text, last_reply[0] + reply]
-                shared.history['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
-                yield shared.history['visible']
+                output['internal'][-1] = [text, last_reply[0] + reply]
+                output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
+                yield output
             elif not (j == 0 and visible_reply.strip() == ''):
-                shared.history['internal'][-1] = [text, reply]
-                shared.history['visible'][-1] = [visible_text, visible_reply]
-                yield shared.history['visible']
+                output['internal'][-1] = [text, reply.lstrip(' ')]
+                output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
+                yield output

             if next_character_found:
                 break

@@ -257,7 +266,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
         else:
             cumulative_reply = reply

-    yield shared.history['visible']
+    yield output


 def impersonate_wrapper(text, state):

@@ -291,21 +300,24 @@ def impersonate_wrapper(text, state):
     yield cumulative_reply


-def generate_chat_reply(text, state, regenerate=False, _continue=False):
+def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True):
     if regenerate or _continue:
         text = ''
-        if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
-            yield shared.history['visible']
+        if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
+            yield history
             return

-    for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue):
+    for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
         yield history


 # Same as above but returns HTML
 def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
-    for history in generate_chat_reply(text, state, regenerate, _continue):
-        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'])
+    for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)):
+        if i != 0:
+            shared.history = copy.deepcopy(history)
+
+        yield chat_html_wrapper(history['visible'], state['name1'], state['name2'], state['mode'], state['chat_style'])


 def remove_last_message():
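The net effect of this refactor is that chat generation no longer mutates shared.history as a side effect: callers pass a history in, and each yield is a copy that the caller decides whether to keep (the Gradio wrapper writes it back to shared.history; the API handlers do not). A sketch of the new calling convention (assumes the webui modules are importable and a model is loaded; state would normally come from build_parameters(body, chat=True), as in the API handlers):

from extensions.api.util import build_parameters
from modules.chat import generate_chat_reply

state = build_parameters({'mode': 'chat', 'character': 'Example',
                          'instruction_template': 'Vicuna-v1.1'}, chat=True)
state['stream'] = False

history = {'internal': [], 'visible': []}
final = history
for updated in generate_chat_reply('Hello!', history, state, loading_message=False):
    final = updated  # each yield is a fresh copy; shared.history stays untouched

print(final['visible'][-1][1])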