Merge pull request #123 from oobabooga/refactor_gradio

Create new extensions engine
oobabooga committed 2023-02-25 01:51:58 -03:00 (committed by GitHub)
commit 3ef0f2ea7e
13 changed files with 312 additions and 235 deletions

README.md

@@ -134,7 +134,6 @@ Optionally, you can use the following command-line flags:
 | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
 | `--chat` | Launch the web UI in chat mode.|
 | `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
-| `--picture` | Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP. |
 | `--cpu` | Use the CPU to generate text.|
 | `--load-in-8bit` | Load the model with 8-bit precision.|
 | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |

api-example.py

@@ -41,7 +41,6 @@ response = requests.post(f"http://{server}:7860/run/textgen", json={
         prompt,
         params['max_new_tokens'],
         params['do_sample'],
-        params['max_new_tokens'],
         params['temperature'],
         params['top_p'],
         params['typical_p'],

extensions/character_bias/script.py

@@ -1,3 +1,5 @@
+import gradio as gr
+
 params = {
     "bias string": " *I speak in an annoyingly cute way*",
 }
@@ -25,3 +27,10 @@ def bot_prefix_modifier(string):
     """
     return f'{string} {params["bias string"].strip()} '
+
+def ui():
+    # Gradio elements
+    string = gr.Textbox(value=params["bias string"], label='Character bias')
+
+    # Event functions to update the parameters in the backend
+    string.change(lambda x: params.update({"bias string": x}), string, None)
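The diff above is essentially the whole contract of the new extensions engine: a script exposes a params dict plus optional input_modifier/output_modifier/bot_prefix_modifier functions, and an optional ui() that builds and wires its own Gradio widgets. A minimal sketch of a complete script under this contract (the "suffix" parameter and its values are illustrative, not part of this commit):

import gradio as gr

params = {
    "suffix": " (edited)",  # illustrative default, not from this commit
}

def output_modifier(string):
    # Applied to the model output before it is displayed
    return string + params["suffix"]

def ui():
    # Called by create_extensions_block() inside the server's Gradio context
    suffix = gr.Textbox(value=params["suffix"], label='Reply suffix')
    # Keep the backend params dict in sync with the widget
    suffix.change(lambda x: params.update({"suffix": x}), suffix, None)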

extensions/google_translate/script.py

@@ -1,9 +1,12 @@
+import gradio as gr
 from deep_translator import GoogleTranslator

 params = {
     "language string": "ja",
 }

+language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
+
 def input_modifier(string):
     """
     This function is applied to your text inputs before
@@ -27,3 +30,13 @@ def bot_prefix_modifier(string):
     """
     return string
+
+def ui():
+    # Finding the language name from the language code to use as the default value
+    language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
+
+    # Gradio elements
+    language = gr.Dropdown(value=language_name, choices=[k for k in language_codes], label='Language')
+
+    # Event functions to update the parameters in the backend
+    language.change(lambda x: params.update({"language string": language_codes[x]}), language, None)

extensions/send_pictures/script.py

@@ -0,0 +1,60 @@
+import base64
+from io import BytesIO
+
+import gradio as gr
+
+import modules.chat as chat
+import modules.shared as shared
+from modules.bot_picture import caption_image
+
+params = {
+}
+
+# If 'state' is 'temporary' or 'permanent', will hijack the next
+# chatbot wrapper call with a custom input text and optionally
+# custom output text
+input_hijack = {
+    'state': 'off',
+    'value': ["", ""]
+}
+
+def generate_chat_picture(picture, name1, name2):
+    text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
+    buffer = BytesIO()
+    picture.save(buffer, format="JPEG")
+    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
+    visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
+    return text, visible_text
+
+def input_modifier(string):
+    """
+    This function is applied to your text inputs before
+    they are fed into the model.
+    """
+    return string
+
+def output_modifier(string):
+    """
+    This function is applied to the model outputs.
+    """
+    return string
+
+def bot_prefix_modifier(string):
+    """
+    This function is only applied in chat mode. It modifies
+    the prefix text for the Bot and can be used to bias its
+    behavior.
+    """
+    return string
+
+def ui():
+    picture_select = gr.Image(label='Send a picture', type='pil')
+
+    function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
+
+    picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
+    picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+    picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
+
+#parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.')
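The input_hijack dict above is the handshake with chatbot_wrapper in modules/chat.py further down: the upload handler fills value and flips state, and the next wrapper call consumes the pair as (internal text, visible text). A stripped-down sketch of that round trip, with a stand-in caption since no BLIP model is loaded here:

input_hijack = {
    'state': 'off',
    'value': ["", ""]
}

def arm(caption):
    # What the picture_select.upload handler does, minus the base64 encoding
    text = f'*sends a picture that contains the following: "{caption}"*'
    visible_text = '<img src="data:image/jpeg;base64,...">'
    input_hijack.update({'state': True, 'value': [text, visible_text]})

def consume(user_text):
    # What chatbot_wrapper does when it sees an armed hijack
    if input_hijack['state'] == True:
        return input_hijack['value']
    return [user_text, user_text]

arm('a cat on a couch')
print(consume('ignored'))  # the hijacked pair replaces the typed input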

extensions/silero_tts/script.py

@@ -1,6 +1,7 @@
 import asyncio
 from pathlib import Path

+import gradio as gr
 import torch

 torch._C._jit_set_profiling_mode(False)
@@ -81,3 +82,12 @@ def bot_prefix_modifier(string):
     """
     return string
+
+def ui():
+    # Gradio elements
+    activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
+    voice = gr.Dropdown(value=params['speaker'], choices=[f'en_{i}' for i in range(1, 118)], label='TTS voice')
+
+    # Event functions to update the parameters in the backend
+    activate.change(lambda x: params.update({"activate": x}), activate, None)
+    voice.change(lambda x: params.update({"speaker": x}), voice, None)

modules/chat.py

@@ -4,19 +4,16 @@ import io
 import json
 import re
 from datetime import datetime
-from io import BytesIO
 from pathlib import Path

 from PIL import Image

 import modules.shared as shared
+import modules.extensions as extensions_module
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_chat_html
 from modules.text_generation import encode, generate_reply, get_max_prompt_length

-if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
-    import modules.bot_picture as bot_picture
-
 # This gets the new line characters right.
 def clean_chat_message(text):
     text = text.replace('\n', '\n\n')
@@ -24,16 +21,16 @@ def clean_chat_message(text):
     text = text.strip()
     return text

-def generate_chat_prompt(user_input, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
+def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
     user_input = clean_chat_message(user_input)
     rows = [f"{context.strip()}\n"]

     if shared.soft_prompt:
         chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
-    max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
+    max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)

     i = len(shared.history['internal'])-1
-    while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
+    while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
         rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
         if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
             rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n")
@@ -47,7 +44,7 @@ def generate_chat_prompt(user_input, tokens, name1, name2, context, chat_prompt_
         rows.append(f"{name1}:")
         limit = 2

-    while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length:
+    while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
         rows.pop(1)

     prompt = ''.join(rows)
@@ -84,81 +81,87 @@ def extract_message_from_reply(question, reply, current, other, check, extension
     return reply, next_character_found, substring_found

-def generate_chat_picture(picture, name1, name2):
-    text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*'
-    buffer = BytesIO()
-    picture.save(buffer, format="JPEG")
-    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
-    visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
-    return text, visible_text
-
 def stop_everything_event():
     shared.stop_everything = True

-def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     shared.stop_everything = False
     just_started = True
     eos_token = '\n' if check else None

     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"

-    if shared.args.picture and picture is not None:
-        text, visible_text = generate_chat_picture(picture, name1, name2)
-    else:
-        visible_text = text
+    # Check if any extension wants to hijack this function call
+    visible_text = None
+    custom_prompt_generator = None
+    for extension, _ in extensions_module.iterator():
+        if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
+            text, visible_text = extension.input_hijack['value']
+        if custom_prompt_generator is None and hasattr(extension, 'custom_prompt_generator'):
+            custom_prompt_generator = extension.custom_prompt_generator
+
+    if visible_text is None:
+        visible_text = text
     if shared.args.chat:
         visible_text = visible_text.replace('\n', '<br>')
     text = apply_extensions(text, "input")
-    prompt = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size)
+
+    if custom_prompt_generator is None:
+        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+    else:
+        prompt = custom_prompt_generator(text, max_new_tokens, name1, name2, context, chat_prompt_size)

     # Generate
-    for reply in generate_reply(prompt, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
+    reply = ' '
+    for i in range(chat_generation_attempts):
+        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):

            # Extracting the reply
            reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name2, name1, check, extensions=True)
            visible_reply = apply_extensions(reply, "output")
            if shared.args.chat:
                visible_reply = visible_reply.replace('\n', '<br>')

            # We need this global variable to handle the Stop event,
            # otherwise gradio gets confused
            if shared.stop_everything:
                return shared.history['visible']
            if just_started:
                just_started = False
                shared.history['internal'].append(['', ''])
                shared.history['visible'].append(['', ''])

            shared.history['internal'][-1] = [text, reply]
            shared.history['visible'][-1] = [visible_text, visible_reply]
            if not substring_found:
                yield shared.history['visible']
            if next_character_found:
                break
    yield shared.history['visible']
-def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, generation_attempts=1):
     eos_token = '\n' if check else None

     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"

-    prompt = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True)
+    prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)

-    for reply in generate_reply(prompt, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
-        reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, extensions=False)
-        if not substring_found:
-            yield reply
-        if next_character_found:
-            break
-    yield reply
+    reply = ' '
+    for i in range(generation_attempts):
+        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
+            reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, extensions=False)
+            if not substring_found:
+                yield reply
+            if next_character_found:
+                break
+    yield reply

-def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
-    for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
+def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+    for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
         yield generate_chat_html(_history, name1, name2, shared.character)

-def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     if shared.character != 'None' and len(shared.history['visible']) == 1:
         if shared.args.cai_chat:
             yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
@@ -168,7 +171,7 @@ def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top
         last_visible = shared.history['visible'].pop()
         last_internal = shared.history['internal'].pop()

-        for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
+        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
             if shared.args.cai_chat:
                 shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
                 yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
@@ -253,7 +256,7 @@ def tokenize_dialogue(dialogue, name1, name2):
            _history.append(entry)
            entry = ['', '']

-    print(f"\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
+    print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
     for row in _history:
         for column in row:
             print("\n")
@@ -301,8 +304,8 @@ def load_history(file, name1, name2):
        shared.history['visible'] = copy.deepcopy(shared.history['internal'])

 def load_default_history(name1, name2):
-    if Path(f'logs/persistent.json').exists():
-        load_history(open(Path(f'logs/persistent.json'), 'rb').read(), name1, name2)
+    if Path('logs/persistent.json').exists():
+        load_history(open(Path('logs/persistent.json'), 'rb').read(), name1, name2)
     else:
         shared.history['internal'] = []
         shared.history['visible'] = []
@@ -370,5 +373,5 @@ def upload_tavern_character(img, name1, name2):

 def upload_your_profile_picture(img):
     img = Image.open(io.BytesIO(img))
-    img.save(Path(f'img_me.png'))
-    print(f'Profile picture saved to "img_me.png"')
+    img.save(Path('img_me.png'))
+    print('Profile picture saved to "img_me.png"')
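The new chat_generation_attempts loop deserves a note: reply is seeded with ' ' and each attempt calls generate_reply(prompt+reply, ...), so every extra attempt resumes generation from wherever the previous one stopped instead of starting over. A toy sketch of that control flow, with a stub generator standing in for generate_reply:

prompt = "Bot:"

def fake_generate_reply(question):
    # Stand-in: pretend the model appends one more word to whatever
    # continuation it was handed past the original prompt
    yield question[len(prompt):] + " more"

reply = ' '
for i in range(3):  # chat_generation_attempts
    for reply in fake_generate_reply(prompt + reply):
        pass

print(repr(reply))  # '  more more more': each attempt extended the last one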

modules/extensions.py

@@ -1,5 +1,3 @@
-import gradio as gr
-
 import extensions
 import modules.shared as shared
@@ -13,7 +11,7 @@ def load_extensions():
         print(f'Loading the extension "{name}"... ', end='')
         exec(f"import extensions.{name}.script")
         state[name] = [True, i]
-        print(f'Ok.')
+        print('Ok.')

 # This iterator returns the extensions in the order specified in the command-line
 def iterator():
@@ -32,31 +30,15 @@ def apply_extensions(text, typ):
             text = extension.bot_prefix_modifier(text)
     return text

-def update_extensions_parameters(*args):
-    i = 0
-    for extension, _ in iterator():
-        for param in extension.params:
-            if len(args) >= i+1:
-                extension.params[param] = eval(f"args[{i}]")
-                i += 1
-
 def create_extensions_block():
-    extensions_ui_elements = []
-    default_values = []
-    if not (shared.args.chat or shared.args.cai_chat):
-        gr.Markdown('## Extensions parameters')
+    # Updating the default values
     for extension, name in iterator():
         for param in extension.params:
             _id = f"{name}-{param}"
-            default_value = shared.settings[_id] if _id in shared.settings else extension.params[param]
-            default_values.append(default_value)
-            if type(extension.params[param]) == str:
-                extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{name}-{param}"))
-            elif type(extension.params[param]) in [int, float]:
-                extensions_ui_elements.append(gr.Number(value=default_value, label=f"{name}-{param}"))
-            elif type(extension.params[param]) == bool:
-                extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{name}-{param}"))
-    update_extensions_parameters(*default_values)
+            if _id in shared.settings:
+                extension.params[param] = shared.settings[_id]

-    btn_extensions = gr.Button("Apply")
-    btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
+    # Creating the extension ui elements
+    for extension, name in iterator():
+        if hasattr(extension, "ui"):
+            extension.ui()
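With update_extensions_parameters and the generic Apply button gone, defaults now flow one way: create_extensions_block() first overwrites each extension's params from shared.settings, keyed as '<extension name>-<param name>', then delegates widget construction to the extension's own ui(). The override half in isolation, assuming a google_translate entry in the user's settings file:

settings = {"google_translate-language string": "fr"}  # e.g. loaded from a settings JSON

params = {"language string": "ja"}  # the extension's shipped default
name = "google_translate"

for param in params:
    _id = f"{name}-{param}"
    if _id in settings:
        params[param] = settings[_id]

print(params)  # {'language string': 'fr'}: the user setting wins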

modules/models.py

@@ -117,7 +117,7 @@ def load_model(model_name):
         model = eval(command)

     # Loading the tokenizer
-    if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"models/gpt-j-6B/").exists():
+    if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/"))

modules/shared.py

@@ -11,6 +11,12 @@ history = {'internal': [], 'visible': []}
 character = 'None'
 stop_everything = False

+# UI elements (buttons, sliders, HTML, etc)
+gradio = {}
+
+# Generation input parameters
+input_params = []
+
 settings = {
     'max_new_tokens': 200,
     'max_new_tokens_min': 1,
@@ -25,6 +31,9 @@ settings = {
     'chat_prompt_size': 2048,
     'chat_prompt_size_min': 0,
     'chat_prompt_size_max': 2048,
+    'chat_generation_attempts': 1,
+    'chat_generation_attempts_min': 1,
+    'chat_generation_attempts_max': 5,
     'preset_pygmalion': 'Pygmalion',
     'name1_pygmalion': 'You',
     'name2_pygmalion': 'Kawaii',
@@ -37,7 +46,6 @@ parser.add_argument('--model', type=str, help='Name of the model to load by defa
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
 parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
-parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.')
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
 parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
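The two new globals are the backbone of the refactor on the server side: every widget is registered in the shared.gradio dict under a string key as it is created, and shared.input_params is then assembled by name, so other modules and extensions (send_pictures wiring events to shared.gradio['display'], for instance) can reach the same components without threading them through return values. The pattern in isolation, with placeholder objects standing in for Gradio components:

gradio = {}        # name -> Gradio component, filled while the UI is built
input_params = []  # positional inputs for the generation event handlers

# server.py registers components under keys instead of local variables
gradio['textbox'] = object()        # stand-in for gr.Textbox(...)
gradio['max_new_tokens'] = object() # stand-in for gr.Slider(...)

# ...and later builds the positional argument list by key, in order
input_params = [gradio[k] for k in ['textbox', 'max_new_tokens']]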

modules/text_generation.py

@@ -72,14 +72,14 @@ def formatted_outputs(reply, model_name):
     else:
         return reply

-def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
+def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):
         question = apply_extensions(question, "input")
     if shared.args.verbose:
         print(f"\n\n{question}\n--------------------\n")

-    input_ids = encode(question, tokens)
+    input_ids = encode(question, max_new_tokens)
     cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
     if not shared.args.flexgen:
         n = shared.tokenizer.eos_token_id if eos_token is None else shared.tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
@@ -126,9 +126,9 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
     if shared.args.deepspeed:
         generate_params.append("synced_gpus=True")
     if shared.args.no_stream:
-        generate_params.append(f"max_new_tokens=tokens")
+        generate_params.append("max_new_tokens=max_new_tokens")
     else:
-        generate_params.append(f"max_new_tokens=8")
+        generate_params.append("max_new_tokens=8")
     if shared.soft_prompt:
         inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
@@ -156,7 +156,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
     # Generate the reply 8 tokens at a time
     else:
         yield formatted_outputs(original_question, shared.model_name)
-        for i in tqdm(range(tokens//8+1)):
+        for i in tqdm(range(max_new_tokens//8+1)):
             with torch.no_grad():
                 output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
                 if shared.soft_prompt:
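The quoting fixes in the no_stream branch are worth a second look: generate_params is a list of source-code fragments that generate_reply later splices into an eval'd shared.model.generate(...) call, so the old f"max_new_tokens=tokens" never interpolated anything (an f-string with no braces is just a string); the name is resolved at eval time in the function's scope, and the rename to max_new_tokens makes the plain string correct. A sketch of that string-splicing pattern:

max_new_tokens = 8

def generate(**kwargs):
    # Stand-in for shared.model.generate
    return kwargs

generate_params = ["max_new_tokens=max_new_tokens"]
output = eval(f"generate({', '.join(generate_params)})")
print(output)  # {'max_new_tokens': 8}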

server.py

@@ -19,7 +19,7 @@ from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply

 if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
-    print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")
+    print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')

 # Loading custom settings
 if shared.args.settings is not None and Path(shared.args.settings).exists():
@@ -34,13 +34,13 @@ def get_available_presets():
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)

 def get_available_characters():
-    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
+    return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)

 def get_available_extensions():
     return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)

 def get_available_softprompts():
-    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
+    return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)

 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
@@ -100,50 +100,49 @@ def create_settings_menus():
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
-                ui.create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
+                shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
+                ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
         with gr.Column():
             with gr.Row():
-                preset_menu = gr.Dropdown(choices=available_presets, value=shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
-                ui.create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
+                shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+                ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')

-    with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"):
+    with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'):
         with gr.Row():
-            do_sample = gr.Checkbox(value=generate_params['do_sample'], label="do_sample")
-            temperature = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label="temperature")
+            shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
+            shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
         with gr.Row():
-            top_k = gr.Slider(0,200,value=generate_params['top_k'],step=1,label="top_k")
-            top_p = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label="top_p")
+            shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
+            shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
         with gr.Row():
-            repetition_penalty = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label="repetition_penalty")
-            no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size")
+            shared.gradio['repetition_penalty'] = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
+            shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
         with gr.Row():
-            typical_p = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label="typical_p")
-            min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if shared.args.no_stream else 0, label="min_length", interactive=shared.args.no_stream)
+            shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
+            shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)

-        gr.Markdown("Contrastive search:")
-        penalty_alpha = gr.Slider(0, 5, value=generate_params["penalty_alpha"], label="penalty_alpha")
+        gr.Markdown('Contrastive search:')
+        shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')

-        gr.Markdown("Beam search (uses a lot of VRAM):")
+        gr.Markdown('Beam search (uses a lot of VRAM):')
         with gr.Row():
-            num_beams = gr.Slider(1, 20, step=1, value=generate_params["num_beams"], label="num_beams")
-            length_penalty = gr.Slider(-5, 5, value=generate_params["length_penalty"], label="length_penalty")
-        early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
+            shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
+            shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
+        shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')

-    with gr.Accordion("Soft prompt", open=False, elem_id="accordion"):
+    with gr.Accordion('Soft prompt', open=False, elem_id='accordion'):
         with gr.Row():
-            softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
-            ui.create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
+            shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
+            ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')

         gr.Markdown('Upload a soft prompt (.zip format):')
         with gr.Row():
-            upload_softprompt = gr.File(type='binary', file_types=[".zip"])
+            shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])

-    model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
-    preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
-    softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
-    upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])
-    return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
+    shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
+    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
+    shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
+    shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])

 available_models = get_available_models()
 available_presets = get_available_presets()
@@ -159,25 +158,24 @@ if shared.args.model is not None:
     shared.model_name = shared.args.model
 else:
     if len(available_models) == 0:
-        print("No models are available! Please download at least one.")
+        print('No models are available! Please download at least one.')
         sys.exit(0)
     elif len(available_models) == 1:
         i = 0
     else:
-        print("The following models are available:\n")
+        print('The following models are available:\n')
         for i, model in enumerate(available_models):
-            print(f"{i+1}. {model}")
-        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
+            print(f'{i+1}. {model}')
+        print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
         i = int(input())-1
         print()
     shared.model_name = available_models[i]
 shared.model, shared.tokenizer = load_model(shared.model_name)

 # UI settings
-buttons = {}
 gen_events = []
 suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
-description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
+description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'

 if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
     default_text = shared.settings['prompt_gpt4chan']
 elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None:
@@ -186,176 +184,169 @@ else:
     default_text = shared.settings['prompt']

 if shared.args.chat or shared.args.cai_chat:
-    with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface:
-        interface.load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
+    with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as shared.gradio['interface']:
         if shared.args.cai_chat:
-            display = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
+            shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
         else:
-            display = gr.Chatbot(value=shared.history['visible'])
-        textbox = gr.Textbox(label='Input')
+            shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'])
+        shared.gradio['textbox'] = gr.Textbox(label='Input')
         with gr.Row():
-            buttons["Stop"] = gr.Button("Stop")
-            buttons["Generate"] = gr.Button("Generate")
-            buttons["Regenerate"] = gr.Button("Regenerate")
+            shared.gradio['Stop'] = gr.Button('Stop')
+            shared.gradio['Generate'] = gr.Button('Generate')
+            shared.gradio['Regenerate'] = gr.Button('Regenerate')
         with gr.Row():
-            buttons["Impersonate"] = gr.Button("Impersonate")
-            buttons["Remove last"] = gr.Button("Remove last")
-            buttons["Clear history"] = gr.Button("Clear history")
+            shared.gradio['Impersonate'] = gr.Button('Impersonate')
+            shared.gradio['Remove last'] = gr.Button('Remove last')
+            shared.gradio['Clear history'] = gr.Button('Clear history')
         with gr.Row():
-            buttons["Send last reply to input"] = gr.Button("Send last reply to input")
-            buttons["Replace last reply"] = gr.Button("Replace last reply")
-        if shared.args.picture:
-            with gr.Row():
-                picture_select = gr.Image(label="Send a picture", type='pil')
+            shared.gradio['Send last reply to input'] = gr.Button('Send last reply to input')
+            shared.gradio['Replace last reply'] = gr.Button('Replace last reply')

-        with gr.Tab("Chat settings"):
-            name1 = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
-            name2 = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
-            context = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=2, label='Context')
+        with gr.Tab('Chat settings'):
+            shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
+            shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
+            shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=2, label='Context')
             with gr.Row():
-                character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
-                ui.create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
+                shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character')
+                ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')

             with gr.Row():
-                check = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
+                shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
             with gr.Row():
                 with gr.Tab('Chat history'):
                     with gr.Row():
                         with gr.Column():
                             gr.Markdown('Upload')
-                            upload_chat_history = gr.File(type='binary', file_types=[".json", ".txt"])
+                            shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
                         with gr.Column():
                             gr.Markdown('Download')
-                            download = gr.File()
-                            buttons["Download"] = gr.Button(value="Click me")
+                            shared.gradio['download'] = gr.File()
+                            shared.gradio['download_button'] = gr.Button(value='Click me')
                 with gr.Tab('Upload character'):
                     with gr.Row():
                         with gr.Column():
                             gr.Markdown('1. Select the JSON file')
-                            upload_char = gr.File(type='binary', file_types=[".json"])
+                            shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
                         with gr.Column():
                             gr.Markdown('2. Select your character\'s profile picture (optional)')
-                            upload_img = gr.File(type='binary', file_types=["image"])
-                    buttons["Upload character"] = gr.Button(value="Submit")
+                            shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
+                    shared.gradio['Upload character'] = gr.Button(value='Submit')
                 with gr.Tab('Upload your profile picture'):
-                    upload_img_me = gr.File(type='binary', file_types=["image"])
+                    shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
                 with gr.Tab('Upload TavernAI Character Card'):
-                    upload_img_tavern = gr.File(type='binary', file_types=["image"])
+                    shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])

-        with gr.Tab("Generation settings"):
+        with gr.Tab('Generation settings'):
             with gr.Row():
                 with gr.Column():
-                    max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+                    shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
                 with gr.Column():
-                    chat_prompt_size_slider = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
+                    shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
+            shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts')

-            preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
+            create_settings_menus()

+        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
         if shared.args.extensions is not None:
-            with gr.Tab("Extensions"):
+            with gr.Tab('Extensions'):
                 extensions_module.create_extensions_block()

-        input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider]
-        if shared.args.picture:
-            input_params.append(picture_select)
-        function_call = "chat.cai_chatbot_wrapper" if shared.args.cai_chat else "chat.chatbot_wrapper"
+        function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'

-        gen_events.append(buttons["Generate"].click(eval(function_call), input_params, display, show_progress=shared.args.no_stream, api_name="textgen"))
-        gen_events.append(textbox.submit(eval(function_call), input_params, display, show_progress=shared.args.no_stream))
-        if shared.args.picture:
-            picture_select.upload(eval(function_call), input_params, display, show_progress=shared.args.no_stream)
-        gen_events.append(buttons["Regenerate"].click(chat.regenerate_wrapper, input_params, display, show_progress=shared.args.no_stream))
-        gen_events.append(buttons["Impersonate"].click(chat.impersonate_wrapper, input_params, textbox, show_progress=shared.args.no_stream))
-        buttons["Stop"].click(chat.stop_everything_event, [], [], cancels=gen_events)
+        gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
+        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
+        shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)

-        buttons["Send last reply to input"].click(chat.send_last_reply_to_input, [], textbox, show_progress=shared.args.no_stream)
-        buttons["Replace last reply"].click(chat.replace_last_reply, [textbox, name1, name2], display, show_progress=shared.args.no_stream)
-        buttons["Clear history"].click(chat.clear_chat_log, [name1, name2], display)
-        buttons["Remove last"].click(chat.remove_last_message, [name1, name2], [display, textbox], show_progress=False)
-        buttons["Download"].click(chat.save_history, inputs=[], outputs=[download])
-        buttons["Upload character"].click(chat.upload_character, [upload_char, upload_img], [character_menu])
+        shared.gradio['Send last reply to input'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
+        shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
+        shared.gradio['Clear history'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
+        shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
+        shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
+        shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])

         # Clearing stuff and saving the history
-        for i in ["Generate", "Regenerate", "Replace last reply"]:
-            buttons[i].click(lambda x: "", textbox, textbox, show_progress=False)
-            buttons[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
-        buttons["Clear history"].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
-        textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
-        textbox.submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+        for i in ['Generate', 'Regenerate', 'Replace last reply']:
+            shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
+            shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+        shared.gradio['Clear history'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+        shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
+        shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)

-        character_menu.change(chat.load_character, [character_menu, name1, name2], [name2, context, display])
-        upload_chat_history.upload(chat.load_history, [upload_chat_history, name1, name2], [])
-        upload_img_tavern.upload(chat.upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu])
-        upload_img_me.upload(chat.upload_your_profile_picture, [upload_img_me], [])
-        if shared.args.picture:
-            picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
+        shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
+        shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
+        shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
+        shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])

         reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
-        reload_inputs = [name1, name2] if shared.args.cai_chat else []
-        upload_chat_history.upload(reload_func, reload_inputs, [display])
-        upload_img_me.upload(reload_func, reload_inputs, [display])
-        interface.load(reload_func, reload_inputs, [display], show_progress=True)
+        reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
+        shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
+        shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
+
+        shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
+        shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)

 elif shared.args.notebook:
-    with gr.Blocks(css=ui.css, analytics_enabled=False) as interface:
+    with gr.Blocks(css=ui.css, analytics_enabled=False) as shared.gradio['interface']:
         gr.Markdown(description)
         with gr.Tab('Raw'):
-            textbox = gr.Textbox(value=default_text, lines=23)
+            shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23)
         with gr.Tab('Markdown'):
-            markdown = gr.Markdown()
+            shared.gradio['markdown'] = gr.Markdown()
         with gr.Tab('HTML'):
-            html = gr.HTML()
+            shared.gradio['html'] = gr.HTML()

-        buttons["Generate"] = gr.Button("Generate")
-        buttons["Stop"] = gr.Button("Stop")
+        shared.gradio['Generate'] = gr.Button('Generate')
+        shared.gradio['Stop'] = gr.Button('Stop')
+        shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])

-        max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
-
-        preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
+        create_settings_menus()

         if shared.args.extensions is not None:
             extensions_module.create_extensions_block()

-        gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen"))
-        gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream))
-        buttons["Stop"].click(None, None, None, cancels=gen_events)
+        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+        output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
+        gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+        gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
+        shared.gradio['Stop'].click(None, None, None, cancels=gen_events)

 else:
-    with gr.Blocks(css=ui.css, analytics_enabled=False) as interface:
+    with gr.Blocks(css=ui.css, analytics_enabled=False) as shared.gradio['interface']:
         gr.Markdown(description)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
textbox = gr.Textbox(value=default_text, lines=15, label='Input') shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
buttons["Generate"] = gr.Button("Generate") shared.gradio['Generate'] = gr.Button('Generate')
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
buttons["Continue"] = gr.Button("Continue") shared.gradio['Continue'] = gr.Button('Continue')
with gr.Column(): with gr.Column():
buttons["Stop"] = gr.Button("Stop") shared.gradio['Stop'] = gr.Button('Stop')
preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() create_settings_menus()
if shared.args.extensions is not None: if shared.args.extensions is not None:
extensions_module.create_extensions_block() extensions_module.create_extensions_block()
with gr.Column(): with gr.Column():
with gr.Tab('Raw'): with gr.Tab('Raw'):
output_textbox = gr.Textbox(lines=15, label='Output') shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output')
with gr.Tab('Markdown'): with gr.Tab('Markdown'):
markdown = gr.Markdown() shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'): with gr.Tab('HTML'):
html = gr.HTML() shared.gradio['html'] = gr.HTML()
gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen")) shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream)) output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
buttons["Stop"].click(None, None, None, cancels=gen_events) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
interface.queue() shared.gradio['interface'].queue()
if shared.args.listen: if shared.args.listen:
interface.launch(prevent_thread_lock=True, share=shared.args.share, server_name="0.0.0.0", server_port=shared.args.listen_port) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port)
else: else:
interface.launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port)
# I think that I will need this later # I think that I will need this later
while True: while True:
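
The net effect of this hunk: every Gradio component that used to live in a local variable or the ad-hoc buttons dict is registered once in the shared.gradio dict, and event-handler input lists are assembled by key lookup (shared.input_params), so extensions and other modules can reach any component by name. Below is a minimal, self-contained sketch of that registry pattern, assuming a toy echo() function in place of generate_reply and only two input components; the dict keys mirror the diff, but everything else here is illustrative scaffolding, not the project's code.

import gradio as gr

gradio = {}  # stand-in for shared.gradio: one registry instead of scattered locals

def echo(text, max_new_tokens):
    # Toy stand-in for generate_reply, which streams model output in the real app.
    return text[:int(max_new_tokens)]

with gr.Blocks() as gradio['interface']:
    # Components are created and registered in one step, keyed by name.
    gradio['textbox'] = gr.Textbox(label='Input')
    gradio['max_new_tokens'] = gr.Slider(1, 2000, value=200, step=1, label='max_new_tokens')
    gradio['Generate'] = gr.Button('Generate')
    gradio['output_textbox'] = gr.Textbox(label='Output')

    # Inputs are assembled by key, as the diff does with shared.input_params;
    # reordering or extending the parameter list only means editing this key list.
    input_params = [gradio[k] for k in ['textbox', 'max_new_tokens']]
    gradio['Generate'].click(echo, input_params, gradio['output_textbox'])

gradio['interface'].launch()

The shared list also explains the Continue handler above: [shared.gradio['output_textbox']] + shared.input_params[1:] swaps only the first input, so generation re-runs on the output box while every sampling parameter stays in place.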

View File

@ -12,6 +12,9 @@
"chat_prompt_size": 2048, "chat_prompt_size": 2048,
"chat_prompt_size_min": 0, "chat_prompt_size_min": 0,
"chat_prompt_size_max": 2048, "chat_prompt_size_max": 2048,
"chat_generation_attempts": 1,
"chat_generation_attempts_min": 1,
"chat_generation_attempts_max": 5,
"preset_pygmalion": "Pygmalion", "preset_pygmalion": "Pygmalion",
"name1_pygmalion": "You", "name1_pygmalion": "You",
"name2_pygmalion": "Kawaii", "name2_pygmalion": "Kawaii",