mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add chat API (#2233)
This commit is contained in:
parent
2aa01e2303
commit
c5af549d4b
88
api-example-chat-stream.py
Normal file
88
api-example-chat-stream.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
except ImportError:
|
||||||
|
print("Websockets package not found. Make sure it's installed.")
|
||||||
|
|
||||||
|
# For local streaming, the websockets are hosted without ssl - ws://
|
||||||
|
HOST = 'localhost:5005'
|
||||||
|
URI = f'ws://{HOST}/api/v1/chat-stream'
|
||||||
|
|
||||||
|
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
||||||
|
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
||||||
|
|
||||||
|
|
||||||
|
async def run(user_input, history):
|
||||||
|
# Note: the selected defaults change from time to time.
|
||||||
|
request = {
|
||||||
|
'user_input': user_input,
|
||||||
|
'history': history,
|
||||||
|
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||||
|
'character': 'Example',
|
||||||
|
'instruction_template': 'Vicuna-v1.1',
|
||||||
|
|
||||||
|
'regenerate': False,
|
||||||
|
'_continue': False,
|
||||||
|
'stop_at_newline': False,
|
||||||
|
'chat_prompt_size': 2048,
|
||||||
|
'chat_generation_attempts': 1,
|
||||||
|
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
||||||
|
|
||||||
|
'max_new_tokens': 250,
|
||||||
|
'do_sample': True,
|
||||||
|
'temperature': 0.7,
|
||||||
|
'top_p': 0.1,
|
||||||
|
'typical_p': 1,
|
||||||
|
'repetition_penalty': 1.18,
|
||||||
|
'top_k': 40,
|
||||||
|
'min_length': 0,
|
||||||
|
'no_repeat_ngram_size': 0,
|
||||||
|
'num_beams': 1,
|
||||||
|
'penalty_alpha': 0,
|
||||||
|
'length_penalty': 1,
|
||||||
|
'early_stopping': False,
|
||||||
|
'seed': -1,
|
||||||
|
'add_bos_token': True,
|
||||||
|
'truncation_length': 2048,
|
||||||
|
'ban_eos_token': False,
|
||||||
|
'skip_special_tokens': True,
|
||||||
|
'stopping_strings': []
|
||||||
|
}
|
||||||
|
|
||||||
|
async with websockets.connect(URI, ping_interval=None) as websocket:
|
||||||
|
await websocket.send(json.dumps(request))
|
||||||
|
|
||||||
|
while True:
|
||||||
|
incoming_data = await websocket.recv()
|
||||||
|
incoming_data = json.loads(incoming_data)
|
||||||
|
|
||||||
|
match incoming_data['event']:
|
||||||
|
case 'text_stream':
|
||||||
|
yield incoming_data['history']
|
||||||
|
case 'stream_end':
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def print_response_stream(user_input, history):
|
||||||
|
cur_len = 0
|
||||||
|
async for new_history in run(user_input, history):
|
||||||
|
cur_message = new_history['visible'][-1][1][cur_len:]
|
||||||
|
cur_len += len(cur_message)
|
||||||
|
print(cur_message, end='')
|
||||||
|
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
|
||||||
|
|
||||||
|
# Basic example
|
||||||
|
history = {'internal': [], 'visible': []}
|
||||||
|
|
||||||
|
# "Continue" example. Make sure to set '_continue' to True above
|
||||||
|
# arr = [user_input, 'Surely, here is']
|
||||||
|
# history = {'internal': [arr], 'visible': [arr]}
|
||||||
|
|
||||||
|
asyncio.run(print_response_stream(user_input, history))
|
68
api-example-chat.py
Normal file
68
api-example-chat.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# For local streaming, the websockets are hosted without ssl - http://
|
||||||
|
HOST = 'localhost:5000'
|
||||||
|
URI = f'http://{HOST}/api/v1/chat'
|
||||||
|
|
||||||
|
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
||||||
|
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
|
||||||
|
|
||||||
|
|
||||||
|
def run(user_input, history):
|
||||||
|
request = {
|
||||||
|
'user_input': user_input,
|
||||||
|
'history': history,
|
||||||
|
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||||
|
'character': 'Example',
|
||||||
|
'instruction_template': 'Vicuna-v1.1',
|
||||||
|
|
||||||
|
'regenerate': False,
|
||||||
|
'_continue': False,
|
||||||
|
'stop_at_newline': False,
|
||||||
|
'chat_prompt_size': 2048,
|
||||||
|
'chat_generation_attempts': 1,
|
||||||
|
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
||||||
|
|
||||||
|
'max_new_tokens': 250,
|
||||||
|
'do_sample': True,
|
||||||
|
'temperature': 0.7,
|
||||||
|
'top_p': 0.1,
|
||||||
|
'typical_p': 1,
|
||||||
|
'repetition_penalty': 1.18,
|
||||||
|
'top_k': 40,
|
||||||
|
'min_length': 0,
|
||||||
|
'no_repeat_ngram_size': 0,
|
||||||
|
'num_beams': 1,
|
||||||
|
'penalty_alpha': 0,
|
||||||
|
'length_penalty': 1,
|
||||||
|
'early_stopping': False,
|
||||||
|
'seed': -1,
|
||||||
|
'add_bos_token': True,
|
||||||
|
'truncation_length': 2048,
|
||||||
|
'ban_eos_token': False,
|
||||||
|
'skip_special_tokens': True,
|
||||||
|
'stopping_strings': []
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(URI, json=request)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()['results'][0]['history']
|
||||||
|
print(json.dumps(result, indent=4))
|
||||||
|
print()
|
||||||
|
print(result['visible'][-1][1])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
|
||||||
|
|
||||||
|
# Basic example
|
||||||
|
history = {'internal': [], 'visible': []}
|
||||||
|
|
||||||
|
# "Continue" example. Make sure to set '_continue' to True above
|
||||||
|
# arr = [user_input, 'Surely, here is']
|
||||||
|
# history = {'internal': [arr], 'visible': [arr]}
|
||||||
|
|
||||||
|
run(user_input, history)
|
@ -4,6 +4,7 @@ from threading import Thread
|
|||||||
|
|
||||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.chat import generate_chat_reply
|
||||||
from modules.text_generation import encode, generate_reply
|
from modules.text_generation import encode, generate_reply
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +47,37 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
'text': answer
|
'text': answer
|
||||||
}]
|
}]
|
||||||
})
|
})
|
||||||
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
elif self.path == '/api/v1/chat':
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
user_input = body['user_input']
|
||||||
|
history = body['history']
|
||||||
|
regenerate = body.get('regenerate', False)
|
||||||
|
_continue = body.get('_continue', False)
|
||||||
|
|
||||||
|
generate_params = build_parameters(body, chat=True)
|
||||||
|
generate_params['stream'] = False
|
||||||
|
|
||||||
|
generator = generate_chat_reply(
|
||||||
|
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
||||||
|
|
||||||
|
answer = history
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
response = json.dumps({
|
||||||
|
'results': [{
|
||||||
|
'history': answer
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
|
||||||
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
elif self.path == '/api/v1/token-count':
|
elif self.path == '/api/v1/token-count':
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
@ -58,6 +89,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
'tokens': len(tokens)
|
'tokens': len(tokens)
|
||||||
}]
|
}]
|
||||||
})
|
})
|
||||||
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
else:
|
else:
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
@ -2,6 +2,7 @@ import extensions.api.blocking_api as blocking_api
|
|||||||
import extensions.api.streaming_api as streaming_api
|
import extensions.api.streaming_api as streaming_api
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api)
|
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api)
|
||||||
streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api)
|
streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api)
|
||||||
|
@ -6,6 +6,7 @@ from websockets.server import serve
|
|||||||
|
|
||||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.chat import generate_chat_reply
|
||||||
from modules.text_generation import generate_reply
|
from modules.text_generation import generate_reply
|
||||||
|
|
||||||
PATH = '/api/v1/stream'
|
PATH = '/api/v1/stream'
|
||||||
@ -13,42 +14,72 @@ PATH = '/api/v1/stream'
|
|||||||
|
|
||||||
async def _handle_connection(websocket, path):
|
async def _handle_connection(websocket, path):
|
||||||
|
|
||||||
if path != PATH:
|
if path == '/api/v1/stream':
|
||||||
print(f'Streaming api: unknown path: {path}')
|
async for message in websocket:
|
||||||
return
|
message = json.loads(message)
|
||||||
|
|
||||||
async for message in websocket:
|
prompt = message['prompt']
|
||||||
message = json.loads(message)
|
generate_params = build_parameters(message)
|
||||||
|
stopping_strings = generate_params.pop('stopping_strings')
|
||||||
|
generate_params['stream'] = True
|
||||||
|
|
||||||
prompt = message['prompt']
|
generator = generate_reply(
|
||||||
generate_params = build_parameters(message)
|
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
stopping_strings = generate_params.pop('stopping_strings')
|
|
||||||
generate_params['stream'] = True
|
|
||||||
|
|
||||||
generator = generate_reply(
|
# As we stream, only send the new bytes.
|
||||||
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
skip_index = 0
|
||||||
|
message_num = 0
|
||||||
|
|
||||||
# As we stream, only send the new bytes.
|
for a in generator:
|
||||||
skip_index = 0
|
to_send = a[skip_index:]
|
||||||
message_num = 0
|
await websocket.send(json.dumps({
|
||||||
|
'event': 'text_stream',
|
||||||
|
'message_num': message_num,
|
||||||
|
'text': to_send
|
||||||
|
}))
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
skip_index += len(to_send)
|
||||||
|
message_num += 1
|
||||||
|
|
||||||
for a in generator:
|
|
||||||
to_send = a[skip_index:]
|
|
||||||
await websocket.send(json.dumps({
|
await websocket.send(json.dumps({
|
||||||
'event': 'text_stream',
|
'event': 'stream_end',
|
||||||
'message_num': message_num,
|
'message_num': message_num
|
||||||
'text': to_send
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
elif path == '/api/v1/chat-stream':
|
||||||
|
async for message in websocket:
|
||||||
|
body = json.loads(message)
|
||||||
|
|
||||||
skip_index += len(to_send)
|
user_input = body['user_input']
|
||||||
message_num += 1
|
history = body['history']
|
||||||
|
generate_params = build_parameters(body, chat=True)
|
||||||
|
generate_params['stream'] = True
|
||||||
|
regenerate = body.get('regenerate', False)
|
||||||
|
_continue = body.get('_continue', False)
|
||||||
|
|
||||||
await websocket.send(json.dumps({
|
generator = generate_chat_reply(
|
||||||
'event': 'stream_end',
|
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
||||||
'message_num': message_num
|
|
||||||
}))
|
message_num = 0
|
||||||
|
for a in generator:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
'event': 'text_stream',
|
||||||
|
'message_num': message_num,
|
||||||
|
'history': a
|
||||||
|
}))
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
message_num += 1
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
'event': 'stream_end',
|
||||||
|
'message_num': message_num
|
||||||
|
}))
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f'Streaming api: unknown path: {path}')
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
async def _run(host: str, port: int):
|
async def _run(host: str, port: int):
|
||||||
|
@ -3,18 +3,11 @@ import traceback
|
|||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from modules.text_generation import get_encoded_length
|
from modules import shared
|
||||||
|
from modules.chat import load_character
|
||||||
|
|
||||||
|
|
||||||
def build_parameters(body):
|
def build_parameters(body, chat=False):
|
||||||
prompt = body['prompt']
|
|
||||||
|
|
||||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
|
||||||
max_context = body.get('max_context_length', 2048)
|
|
||||||
while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
|
|
||||||
prompt_lines.pop(0)
|
|
||||||
|
|
||||||
prompt = '\n'.join(prompt_lines)
|
|
||||||
|
|
||||||
generate_params = {
|
generate_params = {
|
||||||
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
|
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
|
||||||
@ -33,13 +26,34 @@ def build_parameters(body):
|
|||||||
'early_stopping': bool(body.get('early_stopping', False)),
|
'early_stopping': bool(body.get('early_stopping', False)),
|
||||||
'seed': int(body.get('seed', -1)),
|
'seed': int(body.get('seed', -1)),
|
||||||
'add_bos_token': bool(body.get('add_bos_token', True)),
|
'add_bos_token': bool(body.get('add_bos_token', True)),
|
||||||
'truncation_length': int(body.get('truncation_length', 2048)),
|
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
|
||||||
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||||
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
|
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
|
||||||
'custom_stopping_strings': '', # leave this blank
|
'custom_stopping_strings': '', # leave this blank
|
||||||
'stopping_strings': body.get('stopping_strings', []),
|
'stopping_strings': body.get('stopping_strings', []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if chat:
|
||||||
|
character = body.get('character')
|
||||||
|
instruction_template = body.get('instruction_template')
|
||||||
|
name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
|
||||||
|
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True)
|
||||||
|
generate_params.update({
|
||||||
|
'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
|
||||||
|
'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
|
||||||
|
'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
|
||||||
|
'mode': str(body.get('mode', 'chat')),
|
||||||
|
'name1': name1,
|
||||||
|
'name2': name2,
|
||||||
|
'context': context,
|
||||||
|
'greeting': greeting,
|
||||||
|
'name1_instruct': name1_instruct,
|
||||||
|
'name2_instruct': name2_instruct,
|
||||||
|
'context_instruct': context_instruct,
|
||||||
|
'turn_template': turn_template,
|
||||||
|
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
|
||||||
|
})
|
||||||
|
|
||||||
return generate_params
|
return generate_params
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,12 +4,12 @@ import textwrap
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
from modules import chat, shared
|
from modules import chat, shared
|
||||||
|
|
||||||
from .chromadb import add_chunks_to_collector, make_collector
|
from .chromadb import add_chunks_to_collector, make_collector
|
||||||
from .download_urls import download_urls
|
from .download_urls import download_urls
|
||||||
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'chunk_count': 5,
|
'chunk_count': 5,
|
||||||
'chunk_length': 700,
|
'chunk_length': 700,
|
||||||
@ -40,6 +40,7 @@ def feed_data_into_collector(corpus, chunk_len, chunk_sep):
|
|||||||
data_chunks = [x for y in data_chunks for x in y]
|
data_chunks = [x for y in data_chunks for x in y]
|
||||||
else:
|
else:
|
||||||
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
||||||
|
|
||||||
cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
|
cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
|
||||||
yield cumulative
|
yield cumulative
|
||||||
add_chunks_to_collector(data_chunks, collector)
|
add_chunks_to_collector(data_chunks, collector)
|
||||||
@ -124,7 +125,10 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
|
|
||||||
logging.warning(f'Adding the following new context:\n{additional_context}')
|
logging.warning(f'Adding the following new context:\n{additional_context}')
|
||||||
state['context'] = state['context'].strip() + '\n' + additional_context
|
state['context'] = state['context'].strip() + '\n' + additional_context
|
||||||
state['history'] = [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids]
|
kwargs['history'] = {
|
||||||
|
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
|
||||||
|
'visible': ''
|
||||||
|
}
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
logging.error("Couldn't query the database, moving on...")
|
logging.error("Couldn't query the database, moving on...")
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
impersonate = kwargs.get('impersonate', False)
|
impersonate = kwargs.get('impersonate', False)
|
||||||
_continue = kwargs.get('_continue', False)
|
_continue = kwargs.get('_continue', False)
|
||||||
also_return_rows = kwargs.get('also_return_rows', False)
|
also_return_rows = kwargs.get('also_return_rows', False)
|
||||||
history = state.get('history', shared.history['internal'])
|
history = kwargs.get('history', shared.history)['internal']
|
||||||
is_instruct = state['mode'] == 'instruct'
|
is_instruct = state['mode'] == 'instruct'
|
||||||
|
|
||||||
# Finding the maximum prompt size
|
# Finding the maximum prompt size
|
||||||
@ -59,11 +59,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||||
|
|
||||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||||
|
|
||||||
all_substrings = {
|
all_substrings = {
|
||||||
'chat': get_turn_substrings(state, instruct=False),
|
'chat': get_turn_substrings(state, instruct=False),
|
||||||
'instruct': get_turn_substrings(state, instruct=True)
|
'instruct': get_turn_substrings(state, instruct=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
substrings = all_substrings['instruct' if is_instruct else 'chat']
|
substrings = all_substrings['instruct' if is_instruct else 'chat']
|
||||||
|
|
||||||
# Creating the template for "chat-instruct" mode
|
# Creating the template for "chat-instruct" mode
|
||||||
@ -179,10 +179,11 @@ def extract_message_from_reply(reply, state):
|
|||||||
return reply, next_character_found
|
return reply, next_character_found
|
||||||
|
|
||||||
|
|
||||||
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
|
||||||
|
output = copy.deepcopy(history)
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
logging.error("No model is loaded! Select one in the Model tab.")
|
logging.error("No model is loaded! Select one in the Model tab.")
|
||||||
yield shared.history['visible']
|
yield output
|
||||||
return
|
return
|
||||||
|
|
||||||
# Defining some variables
|
# Defining some variables
|
||||||
@ -200,20 +201,27 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||||||
|
|
||||||
text = apply_extensions('input', text)
|
text = apply_extensions('input', text)
|
||||||
# *Is typing...*
|
# *Is typing...*
|
||||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
if loading_message:
|
||||||
|
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
|
||||||
else:
|
else:
|
||||||
text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0]
|
text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
|
||||||
if regenerate:
|
if regenerate:
|
||||||
shared.history['visible'].pop()
|
output['visible'].pop()
|
||||||
shared.history['internal'].pop()
|
output['internal'].pop()
|
||||||
# *Is typing...*
|
# *Is typing...*
|
||||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
if loading_message:
|
||||||
|
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
|
||||||
elif _continue:
|
elif _continue:
|
||||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]]
|
last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
|
||||||
yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']]
|
if loading_message:
|
||||||
|
yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']}
|
||||||
|
|
||||||
# Generating the prompt
|
# Generating the prompt
|
||||||
kwargs = {'_continue': _continue}
|
kwargs = {
|
||||||
|
'_continue': _continue,
|
||||||
|
'history': output,
|
||||||
|
}
|
||||||
|
|
||||||
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
|
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
prompt = generate_chat_prompt(text, state, **kwargs)
|
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||||
@ -232,22 +240,23 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||||||
# We need this global variable to handle the Stop event,
|
# We need this global variable to handle the Stop event,
|
||||||
# otherwise gradio gets confused
|
# otherwise gradio gets confused
|
||||||
if shared.stop_everything:
|
if shared.stop_everything:
|
||||||
return shared.history['visible']
|
yield output
|
||||||
|
return
|
||||||
|
|
||||||
if just_started:
|
if just_started:
|
||||||
just_started = False
|
just_started = False
|
||||||
if not _continue:
|
if not _continue:
|
||||||
shared.history['internal'].append(['', ''])
|
output['internal'].append(['', ''])
|
||||||
shared.history['visible'].append(['', ''])
|
output['visible'].append(['', ''])
|
||||||
|
|
||||||
if _continue:
|
if _continue:
|
||||||
shared.history['internal'][-1] = [text, last_reply[0] + reply]
|
output['internal'][-1] = [text, last_reply[0] + reply]
|
||||||
shared.history['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
|
output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
|
||||||
yield shared.history['visible']
|
yield output
|
||||||
elif not (j == 0 and visible_reply.strip() == ''):
|
elif not (j == 0 and visible_reply.strip() == ''):
|
||||||
shared.history['internal'][-1] = [text, reply]
|
output['internal'][-1] = [text, reply.lstrip(' ')]
|
||||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
|
||||||
yield shared.history['visible']
|
yield output
|
||||||
|
|
||||||
if next_character_found:
|
if next_character_found:
|
||||||
break
|
break
|
||||||
@ -257,7 +266,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||||||
else:
|
else:
|
||||||
cumulative_reply = reply
|
cumulative_reply = reply
|
||||||
|
|
||||||
yield shared.history['visible']
|
yield output
|
||||||
|
|
||||||
|
|
||||||
def impersonate_wrapper(text, state):
|
def impersonate_wrapper(text, state):
|
||||||
@ -291,21 +300,24 @@ def impersonate_wrapper(text, state):
|
|||||||
yield cumulative_reply
|
yield cumulative_reply
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_reply(text, state, regenerate=False, _continue=False):
|
def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True):
|
||||||
if regenerate or _continue:
|
if regenerate or _continue:
|
||||||
text = ''
|
text = ''
|
||||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
|
||||||
yield shared.history['visible']
|
yield history
|
||||||
return
|
return
|
||||||
|
|
||||||
for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue):
|
for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
|
||||||
yield history
|
yield history
|
||||||
|
|
||||||
|
|
||||||
# Same as above but returns HTML
|
# Same as above but returns HTML
|
||||||
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
||||||
for history in generate_chat_reply(text, state, regenerate, _continue):
|
for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)):
|
||||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'])
|
if i != 0:
|
||||||
|
shared.history = copy.deepcopy(history)
|
||||||
|
|
||||||
|
yield chat_html_wrapper(history['visible'], state['name1'], state['name2'], state['mode'], state['chat_style'])
|
||||||
|
|
||||||
|
|
||||||
def remove_last_message():
|
def remove_last_message():
|
||||||
|
Loading…
Reference in New Issue
Block a user