mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Add chat API (#2233)
This commit is contained in:
parent
2aa01e2303
commit
c5af549d4b
88
api-example-chat-stream.py
Normal file
88
api-example-chat-stream.py
Normal file
@ -0,0 +1,88 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("Websockets package not found. Make sure it's installed.")
|
||||
|
||||
# For local streaming, the websockets are hosted without ssl - ws://
|
||||
HOST = 'localhost:5005'
|
||||
URI = f'ws://{HOST}/api/v1/chat-stream'
|
||||
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
||||
|
||||
|
||||
async def run(user_input, history):
|
||||
# Note: the selected defaults change from time to time.
|
||||
request = {
|
||||
'user_input': user_input,
|
||||
'history': history,
|
||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||
'character': 'Example',
|
||||
'instruction_template': 'Vicuna-v1.1',
|
||||
|
||||
'regenerate': False,
|
||||
'_continue': False,
|
||||
'stop_at_newline': False,
|
||||
'chat_prompt_size': 2048,
|
||||
'chat_generation_attempts': 1,
|
||||
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
||||
|
||||
'max_new_tokens': 250,
|
||||
'do_sample': True,
|
||||
'temperature': 0.7,
|
||||
'top_p': 0.1,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.18,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': []
|
||||
}
|
||||
|
||||
async with websockets.connect(URI, ping_interval=None) as websocket:
|
||||
await websocket.send(json.dumps(request))
|
||||
|
||||
while True:
|
||||
incoming_data = await websocket.recv()
|
||||
incoming_data = json.loads(incoming_data)
|
||||
|
||||
match incoming_data['event']:
|
||||
case 'text_stream':
|
||||
yield incoming_data['history']
|
||||
case 'stream_end':
|
||||
return
|
||||
|
||||
|
||||
async def print_response_stream(user_input, history):
|
||||
cur_len = 0
|
||||
async for new_history in run(user_input, history):
|
||||
cur_message = new_history['visible'][-1][1][cur_len:]
|
||||
cur_len += len(cur_message)
|
||||
print(cur_message, end='')
|
||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
|
||||
|
||||
# Basic example
|
||||
history = {'internal': [], 'visible': []}
|
||||
|
||||
# "Continue" example. Make sure to set '_continue' to True above
|
||||
# arr = [user_input, 'Surely, here is']
|
||||
# history = {'internal': [arr], 'visible': [arr]}
|
||||
|
||||
asyncio.run(print_response_stream(user_input, history))
|
68
api-example-chat.py
Normal file
68
api-example-chat.py
Normal file
@ -0,0 +1,68 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
# For local streaming, the websockets are hosted without ssl - http://
|
||||
HOST = 'localhost:5000'
|
||||
URI = f'http://{HOST}/api/v1/chat'
|
||||
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
||||
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
|
||||
|
||||
|
||||
def run(user_input, history):
|
||||
request = {
|
||||
'user_input': user_input,
|
||||
'history': history,
|
||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||
'character': 'Example',
|
||||
'instruction_template': 'Vicuna-v1.1',
|
||||
|
||||
'regenerate': False,
|
||||
'_continue': False,
|
||||
'stop_at_newline': False,
|
||||
'chat_prompt_size': 2048,
|
||||
'chat_generation_attempts': 1,
|
||||
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
||||
|
||||
'max_new_tokens': 250,
|
||||
'do_sample': True,
|
||||
'temperature': 0.7,
|
||||
'top_p': 0.1,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.18,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': []
|
||||
}
|
||||
|
||||
response = requests.post(URI, json=request)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()['results'][0]['history']
|
||||
print(json.dumps(result, indent=4))
|
||||
print()
|
||||
print(result['visible'][-1][1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
|
||||
|
||||
# Basic example
|
||||
history = {'internal': [], 'visible': []}
|
||||
|
||||
# "Continue" example. Make sure to set '_continue' to True above
|
||||
# arr = [user_input, 'Surely, here is']
|
||||
# history = {'internal': [arr], 'visible': [arr]}
|
||||
|
||||
run(user_input, history)
|
@ -4,6 +4,7 @@ from threading import Thread
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
|
||||
@ -46,7 +47,37 @@ class Handler(BaseHTTPRequestHandler):
|
||||
'text': answer
|
||||
}]
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/chat':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
user_input = body['user_input']
|
||||
history = body['history']
|
||||
regenerate = body.get('regenerate', False)
|
||||
_continue = body.get('_continue', False)
|
||||
|
||||
generate_params = build_parameters(body, chat=True)
|
||||
generate_params['stream'] = False
|
||||
|
||||
generator = generate_chat_reply(
|
||||
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
||||
|
||||
answer = history
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'history': answer
|
||||
}]
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/token-count':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
@ -58,6 +89,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
'tokens': len(tokens)
|
||||
}]
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
@ -2,6 +2,7 @@ import extensions.api.blocking_api as blocking_api
|
||||
import extensions.api.streaming_api as streaming_api
|
||||
from modules import shared
|
||||
|
||||
|
||||
def setup():
|
||||
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api)
|
||||
streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api)
|
||||
|
@ -6,6 +6,7 @@ from websockets.server import serve
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
PATH = '/api/v1/stream'
|
||||
@ -13,10 +14,7 @@ PATH = '/api/v1/stream'
|
||||
|
||||
async def _handle_connection(websocket, path):
|
||||
|
||||
if path != PATH:
|
||||
print(f'Streaming api: unknown path: {path}')
|
||||
return
|
||||
|
||||
if path == '/api/v1/stream':
|
||||
async for message in websocket:
|
||||
message = json.loads(message)
|
||||
|
||||
@ -41,7 +39,6 @@ async def _handle_connection(websocket, path):
|
||||
}))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
skip_index += len(to_send)
|
||||
message_num += 1
|
||||
|
||||
@ -50,6 +47,40 @@ async def _handle_connection(websocket, path):
|
||||
'message_num': message_num
|
||||
}))
|
||||
|
||||
elif path == '/api/v1/chat-stream':
|
||||
async for message in websocket:
|
||||
body = json.loads(message)
|
||||
|
||||
user_input = body['user_input']
|
||||
history = body['history']
|
||||
generate_params = build_parameters(body, chat=True)
|
||||
generate_params['stream'] = True
|
||||
regenerate = body.get('regenerate', False)
|
||||
_continue = body.get('_continue', False)
|
||||
|
||||
generator = generate_chat_reply(
|
||||
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
||||
|
||||
message_num = 0
|
||||
for a in generator:
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'history': a
|
||||
}))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
message_num += 1
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
|
||||
else:
|
||||
print(f'Streaming api: unknown path: {path}')
|
||||
return
|
||||
|
||||
|
||||
async def _run(host: str, port: int):
|
||||
async with serve(_handle_connection, host, port, ping_interval=None):
|
||||
|
@ -3,18 +3,11 @@ import traceback
|
||||
from threading import Thread
|
||||
from typing import Callable, Optional
|
||||
|
||||
from modules.text_generation import get_encoded_length
|
||||
from modules import shared
|
||||
from modules.chat import load_character
|
||||
|
||||
|
||||
def build_parameters(body):
|
||||
prompt = body['prompt']
|
||||
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
def build_parameters(body, chat=False):
|
||||
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
|
||||
@ -33,13 +26,34 @@ def build_parameters(body):
|
||||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
'seed': int(body.get('seed', -1)),
|
||||
'add_bos_token': bool(body.get('add_bos_token', True)),
|
||||
'truncation_length': int(body.get('truncation_length', 2048)),
|
||||
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
|
||||
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
|
||||
'custom_stopping_strings': '', # leave this blank
|
||||
'stopping_strings': body.get('stopping_strings', []),
|
||||
}
|
||||
|
||||
if chat:
|
||||
character = body.get('character')
|
||||
instruction_template = body.get('instruction_template')
|
||||
name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
|
||||
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True)
|
||||
generate_params.update({
|
||||
'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
|
||||
'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
|
||||
'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
|
||||
'mode': str(body.get('mode', 'chat')),
|
||||
'name1': name1,
|
||||
'name2': name2,
|
||||
'context': context,
|
||||
'greeting': greeting,
|
||||
'name1_instruct': name1_instruct,
|
||||
'name2_instruct': name2_instruct,
|
||||
'context_instruct': context_instruct,
|
||||
'turn_template': turn_template,
|
||||
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
|
||||
})
|
||||
|
||||
return generate_params
|
||||
|
||||
|
||||
|
@ -4,12 +4,12 @@ import textwrap
|
||||
|
||||
import gradio as gr
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from modules import chat, shared
|
||||
|
||||
from .chromadb import add_chunks_to_collector, make_collector
|
||||
from .download_urls import download_urls
|
||||
|
||||
|
||||
params = {
|
||||
'chunk_count': 5,
|
||||
'chunk_length': 700,
|
||||
@ -40,6 +40,7 @@ def feed_data_into_collector(corpus, chunk_len, chunk_sep):
|
||||
data_chunks = [x for y in data_chunks for x in y]
|
||||
else:
|
||||
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
||||
|
||||
cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
|
||||
yield cumulative
|
||||
add_chunks_to_collector(data_chunks, collector)
|
||||
@ -124,7 +125,10 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
|
||||
logging.warning(f'Adding the following new context:\n{additional_context}')
|
||||
state['context'] = state['context'].strip() + '\n' + additional_context
|
||||
state['history'] = [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids]
|
||||
kwargs['history'] = {
|
||||
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
|
||||
'visible': ''
|
||||
}
|
||||
except RuntimeError:
|
||||
logging.error("Couldn't query the database, moving on...")
|
||||
|
||||
|
@ -50,7 +50,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
impersonate = kwargs.get('impersonate', False)
|
||||
_continue = kwargs.get('_continue', False)
|
||||
also_return_rows = kwargs.get('also_return_rows', False)
|
||||
history = state.get('history', shared.history['internal'])
|
||||
history = kwargs.get('history', shared.history)['internal']
|
||||
is_instruct = state['mode'] == 'instruct'
|
||||
|
||||
# Finding the maximum prompt size
|
||||
@ -59,11 +59,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
|
||||
all_substrings = {
|
||||
'chat': get_turn_substrings(state, instruct=False),
|
||||
'instruct': get_turn_substrings(state, instruct=True)
|
||||
}
|
||||
|
||||
substrings = all_substrings['instruct' if is_instruct else 'chat']
|
||||
|
||||
# Creating the template for "chat-instruct" mode
|
||||
@ -179,10 +179,11 @@ def extract_message_from_reply(reply, state):
|
||||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
|
||||
output = copy.deepcopy(history)
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logging.error("No model is loaded! Select one in the Model tab.")
|
||||
yield shared.history['visible']
|
||||
yield output
|
||||
return
|
||||
|
||||
# Defining some variables
|
||||
@ -200,20 +201,27 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
|
||||
text = apply_extensions('input', text)
|
||||
# *Is typing...*
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
if loading_message:
|
||||
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
|
||||
else:
|
||||
text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0]
|
||||
text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
|
||||
if regenerate:
|
||||
shared.history['visible'].pop()
|
||||
shared.history['internal'].pop()
|
||||
output['visible'].pop()
|
||||
output['internal'].pop()
|
||||
# *Is typing...*
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
if loading_message:
|
||||
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
|
||||
elif _continue:
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]]
|
||||
yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']]
|
||||
last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
|
||||
if loading_message:
|
||||
yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']}
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'_continue': _continue}
|
||||
kwargs = {
|
||||
'_continue': _continue,
|
||||
'history': output,
|
||||
}
|
||||
|
||||
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
|
||||
if prompt is None:
|
||||
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||
@ -232,22 +240,23 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
if shared.stop_everything:
|
||||
return shared.history['visible']
|
||||
yield output
|
||||
return
|
||||
|
||||
if just_started:
|
||||
just_started = False
|
||||
if not _continue:
|
||||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'].append(['', ''])
|
||||
output['internal'].append(['', ''])
|
||||
output['visible'].append(['', ''])
|
||||
|
||||
if _continue:
|
||||
shared.history['internal'][-1] = [text, last_reply[0] + reply]
|
||||
shared.history['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
|
||||
yield shared.history['visible']
|
||||
output['internal'][-1] = [text, last_reply[0] + reply]
|
||||
output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
|
||||
yield output
|
||||
elif not (j == 0 and visible_reply.strip() == ''):
|
||||
shared.history['internal'][-1] = [text, reply]
|
||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||
yield shared.history['visible']
|
||||
output['internal'][-1] = [text, reply.lstrip(' ')]
|
||||
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
|
||||
yield output
|
||||
|
||||
if next_character_found:
|
||||
break
|
||||
@ -257,7 +266,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
else:
|
||||
cumulative_reply = reply
|
||||
|
||||
yield shared.history['visible']
|
||||
yield output
|
||||
|
||||
|
||||
def impersonate_wrapper(text, state):
|
||||
@ -291,21 +300,24 @@ def impersonate_wrapper(text, state):
|
||||
yield cumulative_reply
|
||||
|
||||
|
||||
def generate_chat_reply(text, state, regenerate=False, _continue=False):
|
||||
def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True):
|
||||
if regenerate or _continue:
|
||||
text = ''
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield shared.history['visible']
|
||||
if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
|
||||
yield history
|
||||
return
|
||||
|
||||
for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue):
|
||||
for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
|
||||
yield history
|
||||
|
||||
|
||||
# Same as above but returns HTML
|
||||
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
||||
for history in generate_chat_reply(text, state, regenerate, _continue):
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'])
|
||||
for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)):
|
||||
if i != 0:
|
||||
shared.history = copy.deepcopy(history)
|
||||
|
||||
yield chat_html_wrapper(history['visible'], state['name1'], state['name2'], state['mode'], state['chat_style'])
|
||||
|
||||
|
||||
def remove_last_message():
|
||||
|
Loading…
Reference in New Issue
Block a user