Add chat API (#2233)

oobabooga committed 2023-05-20 18:42:17 -03:00 (via GitHub)
parent 2aa01e2303
commit c5af549d4b
8 changed files with 317 additions and 67 deletions

api-example-chat-stream.py Normal file

@@ -0,0 +1,88 @@
import asyncio
import json
import sys

try:
    import websockets
except ImportError:
    print("Websockets package not found. Make sure it's installed.")

# For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5005'
URI = f'ws://{HOST}/api/v1/chat-stream'

# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'


async def run(user_input, history):
    # Note: the selected defaults change from time to time.
    request = {
        'user_input': user_input,
        'history': history,
        'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
        'character': 'Example',
        'instruction_template': 'Vicuna-v1.1',
        'regenerate': False,
        '_continue': False,
        'stop_at_newline': False,
        'chat_prompt_size': 2048,
        'chat_generation_attempts': 1,
        'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
        'max_new_tokens': 250,
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'repetition_penalty': 1.18,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    async with websockets.connect(URI, ping_interval=None) as websocket:
        await websocket.send(json.dumps(request))

        while True:
            incoming_data = await websocket.recv()
            incoming_data = json.loads(incoming_data)

            match incoming_data['event']:
                case 'text_stream':
                    yield incoming_data['history']
                case 'stream_end':
                    return


async def print_response_stream(user_input, history):
    cur_len = 0
    async for new_history in run(user_input, history):
        cur_message = new_history['visible'][-1][1][cur_len:]
        cur_len += len(cur_message)
        print(cur_message, end='')
        sys.stdout.flush()  # If we don't flush, we won't see tokens in realtime.


if __name__ == '__main__':
    user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."

    # Basic example
    history = {'internal': [], 'visible': []}

    # "Continue" example. Make sure to set '_continue' to True above
    # arr = [user_input, 'Surely, here is']
    # history = {'internal': [arr], 'visible': [arr]}

    asyncio.run(print_response_stream(user_input, history))
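Since the server keeps no per-connection state, a multi-turn conversation is driven entirely by the client: each request sends the full history, and each 'text_stream' event returns an updated copy of it. A minimal sketch built on the run() generator above (chat_turn is a hypothetical helper, not part of this commit):

    # Hypothetical helper, not part of this commit: run one turn and
    # keep the last history the server streamed back, so it can seed
    # the next request.
    async def chat_turn(user_input, history):
        async for new_history in run(user_input, history):
            history = new_history
        return history

    # history = {'internal': [], 'visible': []}
    # history = asyncio.run(chat_turn('Hi there!', history))
    # history = asyncio.run(chat_turn('Tell me more.', history))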

api-example-chat.py Normal file

@@ -0,0 +1,68 @@
import json

import requests

# For local streaming, the websockets are hosted without ssl - http://
HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/chat'

# For reverse-proxied streaming, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'


def run(user_input, history):
    request = {
        'user_input': user_input,
        'history': history,
        'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
        'character': 'Example',
        'instruction_template': 'Vicuna-v1.1',
        'regenerate': False,
        '_continue': False,
        'stop_at_newline': False,
        'chat_prompt_size': 2048,
        'chat_generation_attempts': 1,
        'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
        'max_new_tokens': 250,
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'repetition_penalty': 1.18,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    response = requests.post(URI, json=request)

    if response.status_code == 200:
        result = response.json()['results'][0]['history']
        print(json.dumps(result, indent=4))
        print()
        print(result['visible'][-1][1])


if __name__ == '__main__':
    user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."

    # Basic example
    history = {'internal': [], 'visible': []}

    # "Continue" example. Make sure to set '_continue' to True above
    # arr = [user_input, 'Surely, here is']
    # history = {'internal': [arr], 'visible': [arr]}

    run(user_input, history)
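The blocking endpoint works the same way: the response carries the updated history, and resending it continues the conversation. A rough sketch of a reusable helper (chat() is hypothetical; any field left out of the request falls back to the defaults in build_parameters() on the server):

    # Hypothetical helper, not part of this commit.
    def chat(user_input, history):
        request = {
            'user_input': user_input,
            'history': history,
            'mode': 'instruct',  # same assumptions as the script above
            'character': 'Example',
            'instruction_template': 'Vicuna-v1.1',
        }
        response = requests.post(URI, json=request)
        response.raise_for_status()
        return response.json()['results'][0]['history']

    # history = {'internal': [], 'visible': []}
    # history = chat('What is a llama?', history)
    # history = chat('How big do they get?', history)

To regenerate the last reply, resend the same history with 'regenerate': True; the server pops the last exchange and generates it again ('user_input' must still be present, but its value is ignored).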

extensions/api/blocking_api.py

@@ -4,6 +4,7 @@ from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
+from modules.chat import generate_chat_reply
from modules.text_generation import encode, generate_reply
@@ -46,7 +47,37 @@ class Handler(BaseHTTPRequestHandler):
                    'text': answer
                }]
            })

            self.wfile.write(response.encode('utf-8'))
+        elif self.path == '/api/v1/chat':
+            self.send_response(200)
+            self.send_header('Content-Type', 'application/json')
+            self.end_headers()
+
+            user_input = body['user_input']
+            history = body['history']
+            regenerate = body.get('regenerate', False)
+            _continue = body.get('_continue', False)
+
+            generate_params = build_parameters(body, chat=True)
+            generate_params['stream'] = False
+
+            generator = generate_chat_reply(
+                user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
+
+            answer = history
+            for a in generator:
+                answer = a
+
+            response = json.dumps({
+                'results': [{
+                    'history': answer
+                }]
+            })
+
+            self.wfile.write(response.encode('utf-8'))
        elif self.path == '/api/v1/token-count':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')

@@ -58,6 +89,7 @@ class Handler(BaseHTTPRequestHandler):
                    'tokens': len(tokens)
                }]
            })

            self.wfile.write(response.encode('utf-8'))
        else:
            self.send_error(404)
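The new /api/v1/chat endpoint drains the whole generator and returns only the final history, so a successful response body looks roughly like this (illustrative values):

    # Illustrative response from POST /api/v1/chat; the values are made up.
    example_response = {
        'results': [{
            'history': {
                'internal': [['Hi!', 'Hello! How can I help you today?']],
                'visible': [['Hi!', 'Hello! How can I help you today?']]
            }
        }]
    }

'internal' holds the raw text as the model sees it, while 'visible' holds the version shown in the UI after extension processing.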

extensions/api/script.py

@@ -2,6 +2,7 @@ import extensions.api.blocking_api as blocking_api
import extensions.api.streaming_api as streaming_api

from modules import shared


def setup():
    blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api)
    streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api)

extensions/api/streaming_api.py

@@ -6,6 +6,7 @@ from websockets.server import serve
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
+from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply

PATH = '/api/v1/stream'

@@ -13,10 +14,7 @@ PATH = '/api/v1/stream'
async def _handle_connection(websocket, path):
-    if path != PATH:
-        print(f'Streaming api: unknown path: {path}')
-        return
-
-    async for message in websocket:
-        message = json.loads(message)
+    if path == '/api/v1/stream':
+        async for message in websocket:
+            message = json.loads(message)

@@ -41,7 +39,6 @@ async def _handle_connection(websocket, path):
                }))

                await asyncio.sleep(0)
                skip_index += len(to_send)
                message_num += 1

@@ -50,6 +47,40 @@ async def _handle_connection(websocket, path):
                'message_num': message_num
            }))

+    elif path == '/api/v1/chat-stream':
+        async for message in websocket:
+            body = json.loads(message)
+
+            user_input = body['user_input']
+            history = body['history']
+            generate_params = build_parameters(body, chat=True)
+            generate_params['stream'] = True
+            regenerate = body.get('regenerate', False)
+            _continue = body.get('_continue', False)
+
+            generator = generate_chat_reply(
+                user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
+
+            message_num = 0
+            for a in generator:
+                await websocket.send(json.dumps({
+                    'event': 'text_stream',
+                    'message_num': message_num,
+                    'history': a
+                }))
+
+                await asyncio.sleep(0)
+                message_num += 1
+
+            await websocket.send(json.dumps({
+                'event': 'stream_end',
+                'message_num': message_num
+            }))
+
+    else:
+        print(f'Streaming api: unknown path: {path}')
+        return


async def _run(host: str, port: int):
    async with serve(_handle_connection, host, port, ping_interval=None):
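Every 'text_stream' event on /api/v1/chat-stream carries the full updated history rather than a text delta, which is why the example client above tracks cur_len and prints only the new suffix. The two event shapes look roughly like this (illustrative values):

    # Illustrative events from /api/v1/chat-stream; the values are made up.
    text_stream_event = {
        'event': 'text_stream',
        'message_num': 0,
        'history': {'internal': [['Hi!', 'Hel']], 'visible': [['Hi!', 'Hel']]}
    }

    stream_end_event = {
        'event': 'stream_end',
        'message_num': 1
    }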

extensions/api/util.py

@@ -3,18 +3,11 @@ import traceback
from threading import Thread
from typing import Callable, Optional

-from modules.text_generation import get_encoded_length
+from modules import shared
+from modules.chat import load_character


-def build_parameters(body):
-    prompt = body['prompt']
-    prompt_lines = [k.strip() for k in prompt.split('\n')]
-    max_context = body.get('max_context_length', 2048)
-    while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
-        prompt_lines.pop(0)
-
-    prompt = '\n'.join(prompt_lines)
+def build_parameters(body, chat=False):
    generate_params = {
        'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),

@@ -33,13 +26,34 @@ def build_parameters(body):
        'early_stopping': bool(body.get('early_stopping', False)),
        'seed': int(body.get('seed', -1)),
        'add_bos_token': bool(body.get('add_bos_token', True)),
-        'truncation_length': int(body.get('truncation_length', 2048)),
+        'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
        'ban_eos_token': bool(body.get('ban_eos_token', False)),
        'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
        'custom_stopping_strings': '',  # leave this blank
        'stopping_strings': body.get('stopping_strings', []),
    }

+    if chat:
+        character = body.get('character')
+        instruction_template = body.get('instruction_template')
+        name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
+        name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True)
+        generate_params.update({
+            'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
+            'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
+            'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
+            'mode': str(body.get('mode', 'chat')),
+            'name1': name1,
+            'name2': name2,
+            'context': context,
+            'greeting': greeting,
+            'name1_instruct': name1_instruct,
+            'name2_instruct': name2_instruct,
+            'context_instruct': context_instruct,
+            'turn_template': turn_template,
+            'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
+        })
+
    return generate_params
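For chat requests the handlers call build_parameters(body, chat=True), which resolves the character and the instruction template on the server. A minimal sketch of that path (assuming a character named 'Example' and a 'Vicuna-v1.1' template exist in the server's folders):

    # Sketch of how the API handlers use this function for chat requests.
    body = {
        'user_input': 'Hi!',
        'history': {'internal': [], 'visible': []},
        'mode': 'chat',
        'character': 'Example',
        'instruction_template': 'Vicuna-v1.1',
    }
    generate_params = build_parameters(body, chat=True)
    # generate_params now holds the sampler settings plus name1/name2,
    # context and greeting from the character file, and the instruct
    # names and turn template from the instruction template.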

extensions/superbooga/script.py

@@ -4,12 +4,12 @@ import textwrap
import gradio as gr
from bs4 import BeautifulSoup

from modules import chat, shared

from .chromadb import add_chunks_to_collector, make_collector
from .download_urls import download_urls

params = {
    'chunk_count': 5,
    'chunk_length': 700,

@@ -40,6 +40,7 @@ def feed_data_into_collector(corpus, chunk_len, chunk_sep):
        data_chunks = [x for y in data_chunks for x in y]
    else:
        data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]

    cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
    yield cumulative
    add_chunks_to_collector(data_chunks, collector)

@@ -124,7 +125,10 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
                logging.warning(f'Adding the following new context:\n{additional_context}')
                state['context'] = state['context'].strip() + '\n' + additional_context
-                state['history'] = [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids]
+                kwargs['history'] = {
+                    'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
+                    'visible': ''
+                }
            except RuntimeError:
                logging.error("Couldn't query the database, moving on...")

modules/chat.py

@@ -50,7 +50,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
    impersonate = kwargs.get('impersonate', False)
    _continue = kwargs.get('_continue', False)
    also_return_rows = kwargs.get('also_return_rows', False)
-    history = state.get('history', shared.history['internal'])
+    history = kwargs.get('history', shared.history)['internal']
    is_instruct = state['mode'] == 'instruct'

    # Finding the maximum prompt size

@@ -59,11 +59,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]

    max_length = min(get_max_prompt_length(state), chat_prompt_size)
    all_substrings = {
        'chat': get_turn_substrings(state, instruct=False),
        'instruct': get_turn_substrings(state, instruct=True)
    }
    substrings = all_substrings['instruct' if is_instruct else 'chat']

    # Creating the template for "chat-instruct" mode
@@ -179,10 +179,11 @@ def extract_message_from_reply(reply, state):
    return reply, next_character_found


-def chatbot_wrapper(text, state, regenerate=False, _continue=False):
+def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
+    output = copy.deepcopy(history)
    if shared.model_name == 'None' or shared.model is None:
        logging.error("No model is loaded! Select one in the Model tab.")
-        yield shared.history['visible']
+        yield output
        return

    # Defining some variables
@@ -200,20 +201,27 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
        text = apply_extensions('input', text)

        # *Is typing...*
-        yield shared.history['visible'] + [[visible_text, shared.processing_message]]
+        if loading_message:
+            yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
    else:
-        text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0]
+        text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
        if regenerate:
-            shared.history['visible'].pop()
-            shared.history['internal'].pop()
+            output['visible'].pop()
+            output['internal'].pop()

            # *Is typing...*
-            yield shared.history['visible'] + [[visible_text, shared.processing_message]]
+            if loading_message:
+                yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
        elif _continue:
-            last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]]
-            yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']]
+            last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
+            if loading_message:
+                yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']}

    # Generating the prompt
-    kwargs = {'_continue': _continue}
+    kwargs = {
+        '_continue': _continue,
+        'history': output,
+    }
    prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
    if prompt is None:
        prompt = generate_chat_prompt(text, state, **kwargs)
@@ -232,22 +240,23 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
            # We need this global variable to handle the Stop event,
            # otherwise gradio gets confused
            if shared.stop_everything:
-                return shared.history['visible']
+                yield output
+                return

            if just_started:
                just_started = False
                if not _continue:
-                    shared.history['internal'].append(['', ''])
-                    shared.history['visible'].append(['', ''])
+                    output['internal'].append(['', ''])
+                    output['visible'].append(['', ''])

            if _continue:
-                shared.history['internal'][-1] = [text, last_reply[0] + reply]
-                shared.history['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
-                yield shared.history['visible']
+                output['internal'][-1] = [text, last_reply[0] + reply]
+                output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
+                yield output
            elif not (j == 0 and visible_reply.strip() == ''):
-                shared.history['internal'][-1] = [text, reply]
-                shared.history['visible'][-1] = [visible_text, visible_reply]
-                yield shared.history['visible']
+                output['internal'][-1] = [text, reply.lstrip(' ')]
+                output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
+                yield output

            if next_character_found:
                break
@@ -257,7 +266,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
        else:
            cumulative_reply = reply

-    yield shared.history['visible']
+    yield output


def impersonate_wrapper(text, state):
@@ -291,21 +300,24 @@ def impersonate_wrapper(text, state):
        yield cumulative_reply


-def generate_chat_reply(text, state, regenerate=False, _continue=False):
+def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True):
    if regenerate or _continue:
        text = ''
-        if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
-            yield shared.history['visible']
+        if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
+            yield history
            return

-    for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue):
+    for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
        yield history


# Same as above but returns HTML
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
-    for history in generate_chat_reply(text, state, regenerate, _continue):
-        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'])
+    for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)):
+        if i != 0:
+            shared.history = copy.deepcopy(history)
+
+        yield chat_html_wrapper(history['visible'], state['name1'], state['name2'], state['mode'], state['chat_style'])
def remove_last_message():
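The net effect of the refactor is that history becomes an explicit argument end to end: the Gradio wrapper keeps passing the global shared.history and copies each yielded update back into it, while the API handlers pass the client-supplied history and leave the global untouched. A condensed sketch of the two call paths (argument lists shortened, not literal code from this commit):

    # UI path: the global history is passed in and updated as the
    # generator yields (condensed from generate_chat_reply_wrapper).
    for history in generate_chat_reply(text, shared.history, state):
        shared.history = copy.deepcopy(history)

    # API path: the caller-supplied history is passed in and the
    # generate_params dict plays the role of state; nothing global
    # is mutated (condensed from the API handlers).
    for history in generate_chat_reply(user_input, body['history'], generate_params, loading_message=False):
        pass  # each yielded history is sent back to the client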