diff --git a/README.md b/README.md index cd5cacc5..5af0f139 100644 --- a/README.md +++ b/README.md @@ -134,7 +134,6 @@ Optionally, you can use the following command-line flags: | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--chat` | Launch the web UI in chat mode.| | `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. | -| `--picture` | Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP. | | `--cpu` | Use the CPU to generate text.| | `--load-in-8bit` | Load the model with 8-bit precision.| | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | diff --git a/api-example.py b/api-example.py index f2f6c51e..0306b7ab 100644 --- a/api-example.py +++ b/api-example.py @@ -41,7 +41,6 @@ response = requests.post(f"http://{server}:7860/run/textgen", json={ prompt, params['max_new_tokens'], params['do_sample'], - params['max_new_tokens'], params['temperature'], params['top_p'], params['typical_p'], diff --git a/extensions/character_bias/script.py b/extensions/character_bias/script.py index 9660a59a..aa949f29 100644 --- a/extensions/character_bias/script.py +++ b/extensions/character_bias/script.py @@ -1,3 +1,5 @@ +import gradio as gr + params = { "bias string": " *I speak in an annoyingly cute way*", } @@ -25,3 +27,10 @@ def bot_prefix_modifier(string): """ return f'{string} {params["bias string"].strip()} ' + +def ui(): + # Gradio elements + string = gr.Textbox(value=params["bias string"], label='Character bias') + + # Event functions to update the parameters in the backend + string.change(lambda x: params.update({"bias string": x}), string, None) diff --git a/extensions/google_translate/script.py b/extensions/google_translate/script.py index 064a7ec7..68bc54b2 100644 --- a/extensions/google_translate/script.py +++ b/extensions/google_translate/script.py @@ -1,9 +1,12 @@ +import gradio as gr from deep_translator import GoogleTranslator params = { "language string": "ja", } +language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 
'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'} + def input_modifier(string): """ This function is applied to your text inputs before @@ -27,3 +30,13 @@ def bot_prefix_modifier(string): """ return string + +def ui(): + # Finding the language name from the language code to use as the default value + language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])] + + # Gradio elements + language = gr.Dropdown(value=language_name, choices=[k for k in language_codes], label='Language') + + # Event functions to update the parameters in the backend + language.change(lambda x: params.update({"language string": language_codes[x]}), language, None)
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py new file mode 100644 index 00000000..fe4f083d --- /dev/null +++ b/extensions/send_pictures/script.py @@ -0,0 +1,60 @@ +import base64 +from io import BytesIO + +import gradio as gr + +import modules.chat as chat +import modules.shared as shared +from modules.bot_picture import caption_image + +params = { +} + +# If 'state' is set (ui() below sets it to True), the next chatbot +# wrapper call will be hijacked: 'value' supplies the custom input +# text and the custom visible text +input_hijack = { + 'state': 'off', + 'value': ["", ""] +} + +def generate_chat_picture(picture, name1, name2): + text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*' + buffer = BytesIO() + picture.save(buffer, format="JPEG") + img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') + visible_text = f'<img src="data:image/jpeg;base64,{img_str}">' + return text, visible_text + +def input_modifier(string): + """ + This function is applied to your text inputs before + they are fed into the model. + """ + + return string + +def output_modifier(string): + """ + This function is applied to the model outputs. + """ + + return string + +def bot_prefix_modifier(string): + """ + This function is only applied in chat mode. It modifies + the prefix text for the Bot and can be used to bias its + behavior. + """ + + return string + +def ui(): + picture_select = gr.Image(label='Send a picture', type='pil') + + function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' + picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None) + picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) + picture_select.upload(lambda : None, [], [picture_select], show_progress=False) +#parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.')
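The `input_hijack` handshake introduced in send_pictures above is the generic replacement for the hard-coded `--picture` path: an extension fills in `state` and `value`, and the next `chatbot_wrapper` call swaps that pair in for the user's input. As a sketch of what the mechanism supports beyond pictures, a hypothetical extension (the name, button, and dice logic are invented for illustration, not part of this commit) could arm the same hijack:

```python
# extensions/dice_roll/script.py -- hypothetical example, not part of this
# commit: reuses the input_hijack handshake that chatbot_wrapper now checks.
import random

import gradio as gr

params = {}

# Same structure chatbot_wrapper looks for: 'state' flags a pending hijack,
# 'value' holds [internal_text, visible_text]
input_hijack = {
    'state': 'off',
    'value': ["", ""]
}

def roll_dice(sides=20):
    result = f'*rolls a d{sides} and gets {random.randint(1, sides)}*'
    # Use the same string for the internal history and the visible chat
    input_hijack.update({'state': True, 'value': [result, result]})

def ui():
    button = gr.Button('Roll d20')
    # Arm the hijack; the next chatbot_wrapper call will consume it
    button.click(roll_dice, [], None)
```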
diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index 3328422b..6c0c4ef6 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -1,6 +1,7 @@ import asyncio from pathlib import Path +import gradio as gr import torch torch._C._jit_set_profiling_mode(False) @@ -81,3 +82,12 @@ def bot_prefix_modifier(string): """ return string + +def ui(): + # Gradio elements + activate = gr.Checkbox(value=params['activate'], label='Activate TTS') + voice = gr.Dropdown(value=params['speaker'], choices=[f'en_{i}' for i in range(1, 118)], label='TTS voice') + + # Event functions to update the parameters in the backend + activate.change(lambda x: params.update({"activate": x}), activate, None) + voice.change(lambda x: params.update({"speaker": x}), voice, None)
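Taken together, character_bias, google_translate, silero_tts, and send_pictures now follow one contract: a module-level `params` dict, optional `input_modifier`/`output_modifier`/`bot_prefix_modifier` hooks, and an optional `ui()` that `create_extensions_block` calls to render widgets wired back into `params`. A minimal skeleton of that contract (the extension name and the `prefix` parameter are placeholders for illustration):

```python
# extensions/skeleton/script.py -- minimal sketch of the extension contract
# this commit settles on; the name and parameter are placeholders.
import gradio as gr

params = {
    "prefix": "",
}

def input_modifier(string):
    # Applied to the user input before it is fed to the model
    return params["prefix"] + string

def output_modifier(string):
    # Applied to the raw model output
    return string

def bot_prefix_modifier(string):
    # Chat mode only: can bias the bot's reply by editing its prefix
    return string

def ui():
    # Called once by create_extensions_block(); .change() keeps the
    # backend params dict in sync with the widget
    prefix = gr.Textbox(value=params["prefix"], label='Prefix')
    prefix.change(lambda x: params.update({"prefix": x}), prefix, None)
```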
diff --git a/modules/chat.py b/modules/chat.py index 4e094329..cbcf1881 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -4,19 +4,16 @@ import io import json import re from datetime import datetime -from io import BytesIO from pathlib import Path from PIL import Image import modules.shared as shared +import modules.extensions as extensions_module from modules.extensions import apply_extensions from modules.html_generator import generate_chat_html from modules.text_generation import encode, generate_reply, get_max_prompt_length -if shared.args.picture and (shared.args.cai_chat or shared.args.chat): - import modules.bot_picture as bot_picture - # This gets the new line characters right. def clean_chat_message(text): text = text.replace('\n', '\n\n') @@ -24,16 +21,16 @@ text = text.strip() return text -def generate_chat_prompt(user_input, tokens, name1, name2, context, chat_prompt_size, impersonate=False): +def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False): user_input = clean_chat_message(user_input) rows = [f"{context.strip()}\n"] if shared.soft_prompt: chat_prompt_size -= shared.soft_prompt_tensor.shape[1] - max_length = min(get_max_prompt_length(tokens), chat_prompt_size) + max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size) i = len(shared.history['internal'])-1 - while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length: + while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n") if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'): rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n") @@ -47,7 +44,7 @@ def generate_chat_prompt(user_input, tokens, name1, name2, context, chat_prompt_ rows.append(f"{name1}:") limit = 2 - while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length: + while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length: rows.pop(1) prompt = ''.join(rows)
@@ -84,81 +81,87 @@ def extract_message_from_reply(question, reply, current, other, check, extension return reply, next_character_found, substring_found -def generate_chat_picture(picture, name1, name2): - text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*' - buffer = BytesIO() - picture.save(buffer, format="JPEG") - img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') - visible_text = f'<img src="data:image/jpeg;base64,{img_str}">' - return text, visible_text - def stop_everything_event(): shared.stop_everything = True -def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): +def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): shared.stop_everything = False just_started = True eos_token = '\n' if check else None - if 'pygmalion' in shared.model_name.lower(): name1 = "You" - if shared.args.picture and picture is not None: - text, visible_text = generate_chat_picture(picture, name1, name2) - else: + # Check if any extension wants to hijack this function call + visible_text = None + custom_prompt_generator = None + for extension, _ in extensions_module.iterator(): + if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: + extension.input_hijack['state'] = 'off' # reset the flag so the hijack only applies to this call + text, visible_text = extension.input_hijack['value'] + if custom_prompt_generator is None and hasattr(extension, 'custom_prompt_generator'): + custom_prompt_generator = extension.custom_prompt_generator + + if visible_text is None: visible_text = text if shared.args.chat: visible_text = visible_text.replace('\n', '<br>') - text = apply_extensions(text, "input") - prompt = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size) + text = apply_extensions(text, "input") + + if custom_prompt_generator is None: + prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) + else: + prompt = custom_prompt_generator(text, max_new_tokens, name1, name2, context, chat_prompt_size) # Generate - for reply in generate_reply(prompt, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): + reply = ' ' + for i in range(chat_generation_attempts): + for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): - # Extracting the reply - reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name2, name1, check, extensions=True) - visible_reply = apply_extensions(reply, "output") - if shared.args.chat: - visible_reply = visible_reply.replace('\n', '<br>') + # Extracting the reply + reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name2, name1, check, extensions=True) + visible_reply = apply_extensions(reply, "output") + if shared.args.chat: + visible_reply = visible_reply.replace('\n', '<br>') - # We need this global variable to handle the Stop event, - # otherwise gradio gets confused - if shared.stop_everything: - return shared.history['visible'] - if just_started: - just_started = False - shared.history['internal'].append(['', '']) - shared.history['visible'].append(['', '']) + # We need this global variable to handle the Stop event, + # otherwise gradio gets confused + if shared.stop_everything: + return shared.history['visible'] + if just_started: + just_started = False + shared.history['internal'].append(['', '']) + shared.history['visible'].append(['', '']) - shared.history['internal'][-1] = [text, reply] - shared.history['visible'][-1] = [visible_text, visible_reply] - if not substring_found: - yield shared.history['visible'] - if next_character_found: - break - yield shared.history['visible'] + shared.history['internal'][-1] = [text, reply] + shared.history['visible'][-1] = [visible_text, visible_reply] + if not substring_found: + yield shared.history['visible'] + if next_character_found: + break + yield shared.history['visible']
-def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): +def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, generation_attempts=1): eos_token = '\n' if check else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" - prompt = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True) + prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) - for reply in generate_reply(prompt, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): - reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, extensions=False) - if not substring_found: - yield reply - if next_character_found: - break - yield reply + reply = ' ' + for i in range(generation_attempts): + for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): + reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, extensions=False) + if not substring_found: + yield reply + if next_character_found: + break + yield reply -def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): - for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): +def 
cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): + for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): yield generate_chat_html(_history, name1, name2, shared.character) -def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): +def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): if shared.character != 'None' and len(shared.history['visible']) == 1: if shared.args.cai_chat: yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) @@ -168,7 +171,7 @@ def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() - for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): + for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) @@ -253,7 +256,7 @@ def tokenize_dialogue(dialogue, name1, name2): _history.append(entry) entry = ['', ''] - print(f"\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='') + print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='') for row in _history: for column in row: print("\n") @@ -301,8 +304,8 @@ def load_history(file, name1, name2): shared.history['visible'] = copy.deepcopy(shared.history['internal']) def load_default_history(name1, name2): - if Path(f'logs/persistent.json').exists(): - load_history(open(Path(f'logs/persistent.json'), 'rb').read(), name1, name2) + if Path('logs/persistent.json').exists(): + load_history(open(Path('logs/persistent.json'), 'rb').read(), name1, name2) else: shared.history['internal'] = [] shared.history['visible'] = [] @@ -370,5 +373,5 @@ def upload_tavern_character(img, name1, name2): def upload_your_profile_picture(img): img = Image.open(io.BytesIO(img)) - img.save(Path(f'img_me.png')) - print(f'Profile picture saved to "img_me.png"') + img.save(Path('img_me.png')) + print('Profile picture saved to "img_me.png"') diff --git a/modules/extensions.py b/modules/extensions.py index 17d9a381..c0da496a 100644 --- 
a/modules/extensions.py +++ b/modules/extensions.py @@ -1,5 +1,3 @@ -import gradio as gr - import extensions import modules.shared as shared @@ -13,7 +11,7 @@ def load_extensions(): print(f'Loading the extension "{name}"... ', end='') exec(f"import extensions.{name}.script") state[name] = [True, i] - print(f'Ok.') + print('Ok.') # This iterator returns the extensions in the order specified in the command-line def iterator(): @@ -32,31 +30,15 @@ def apply_extensions(text, typ): text = extension.bot_prefix_modifier(text) return text -def update_extensions_parameters(*args): - i = 0 - for extension, _ in iterator(): - for param in extension.params: - if len(args) >= i+1: - extension.params[param] = eval(f"args[{i}]") - i += 1 - def create_extensions_block(): - extensions_ui_elements = [] - default_values = [] - if not (shared.args.chat or shared.args.cai_chat): - gr.Markdown('## Extensions parameters') + # Updating the default values for extension, name in iterator(): for param in extension.params: _id = f"{name}-{param}" - default_value = shared.settings[_id] if _id in shared.settings else extension.params[param] - default_values.append(default_value) - if type(extension.params[param]) == str: - extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{name}-{param}")) - elif type(extension.params[param]) in [int, float]: - extensions_ui_elements.append(gr.Number(value=default_value, label=f"{name}-{param}")) - elif type(extension.params[param]) == bool: - extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{name}-{param}")) + if _id in shared.settings: + extension.params[param] = shared.settings[_id] - update_extensions_parameters(*default_values) - btn_extensions = gr.Button("Apply") - btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], []) + # Creating the extension ui elements + for extension, name in iterator(): + if hasattr(extension, "ui"): + extension.ui() diff --git a/modules/models.py b/modules/models.py index 37f7dfd8..0cb9ae6e 100644 --- a/modules/models.py +++ b/modules/models.py @@ -117,7 +117,7 @@ def load_model(model_name): model = eval(command) # Loading the tokenizer - if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"models/gpt-j-6B/").exists(): + if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists(): tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/")) else: tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/")) diff --git a/modules/shared.py b/modules/shared.py index e94b5b65..d59cee99 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -11,6 +11,12 @@ history = {'internal': [], 'visible': []} character = 'None' stop_everything = False +# UI elements (buttons, sliders, HTML, etc) +gradio = {} + +# Generation input parameters +input_params = [] + settings = { 'max_new_tokens': 200, 'max_new_tokens_min': 1, @@ -25,6 +31,9 @@ settings = { 'chat_prompt_size': 2048, 'chat_prompt_size_min': 0, 'chat_prompt_size_max': 2048, + 'chat_generation_attempts': 1, + 'chat_generation_attempts_min': 1, + 'chat_generation_attempts_max': 5, 'preset_pygmalion': 'Pygmalion', 'name1_pygmalion': 'You', 'name2_pygmalion': 'Kawaii', @@ -37,7 +46,6 @@ parser.add_argument('--model', type=str, help='Name of the model to load by defa parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') 
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.') parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') -parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.') parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') diff --git a/modules/text_generation.py b/modules/text_generation.py index 02d1210d..834101ff 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -72,14 +72,14 @@ def formatted_outputs(reply, model_name): else: return reply -def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): +def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): original_question = question if not (shared.args.chat or shared.args.cai_chat): question = apply_extensions(question, "input") if shared.args.verbose: print(f"\n\n{question}\n--------------------\n") - input_ids = encode(question, tokens) + input_ids = encode(question, max_new_tokens) cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" if not shared.args.flexgen: n = shared.tokenizer.eos_token_id if eos_token is None else shared.tokenizer.encode(eos_token, return_tensors='pt')[0][-1] @@ -126,9 +126,9 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top if shared.args.deepspeed: generate_params.append("synced_gpus=True") if shared.args.no_stream: - generate_params.append(f"max_new_tokens=tokens") + generate_params.append("max_new_tokens=max_new_tokens") else: - generate_params.append(f"max_new_tokens=8") + generate_params.append("max_new_tokens=8") if shared.soft_prompt: inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) @@ -156,7 +156,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top # Generate the reply 8 tokens at a time else: yield formatted_outputs(original_question, shared.model_name) - for i in tqdm(range(tokens//8+1)): + for i in tqdm(range(max_new_tokens//8+1)): with torch.no_grad(): output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] if shared.soft_prompt: diff --git a/server.py b/server.py index 004daa2c..e7151933 100644 --- a/server.py +++ b/server.py @@ -19,7 +19,7 @@ from modules.models import load_model, load_soft_prompt from modules.text_generation import generate_reply if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: - print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n") + 
print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n') # Loading custom settings if shared.args.settings is not None and Path(shared.args.settings).exists(): @@ -34,13 +34,13 @@ def get_available_presets(): return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) def get_available_characters(): - return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower) + return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower) def get_available_extensions(): return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower) def get_available_softprompts(): - return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) + return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) def load_model_wrapper(selected_model): if selected_model != shared.model_name: @@ -100,50 +100,49 @@ def create_settings_menus(): with gr.Row(): with gr.Column(): with gr.Row(): - model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model') - ui.create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button") + shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model') + ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button') with gr.Column(): with gr.Row(): - preset_menu = gr.Dropdown(choices=available_presets, value=shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset') - ui.create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button") + shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset') + ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') - with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"): + with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'): with gr.Row(): - do_sample = gr.Checkbox(value=generate_params['do_sample'], label="do_sample") - temperature = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label="temperature") + shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') + shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') with gr.Row(): - top_k = gr.Slider(0,200,value=generate_params['top_k'],step=1,label="top_k") - top_p = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label="top_p") + shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k') + shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p') with gr.Row(): - repetition_penalty = 
gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label="repetition_penalty") - no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size") + shared.gradio['repetition_penalty'] = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty') + shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') with gr.Row(): - typical_p = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label="typical_p") - min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if shared.args.no_stream else 0, label="min_length", interactive=shared.args.no_stream) + shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p') + shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream) - gr.Markdown("Contrastive search:") - penalty_alpha = gr.Slider(0, 5, value=generate_params["penalty_alpha"], label="penalty_alpha") + gr.Markdown('Contrastive search:') + shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') - gr.Markdown("Beam search (uses a lot of VRAM):") + gr.Markdown('Beam search (uses a lot of VRAM):') with gr.Row(): - num_beams = gr.Slider(1, 20, step=1, value=generate_params["num_beams"], label="num_beams") - length_penalty = gr.Slider(-5, 5, value=generate_params["length_penalty"], label="length_penalty") - early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping") + shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams') + shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') + shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') - with gr.Accordion("Soft prompt", open=False, elem_id="accordion"): + with gr.Accordion('Soft prompt', open=False, elem_id='accordion'): with gr.Row(): - softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt') - ui.create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button") + shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') + ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button') gr.Markdown('Upload a soft prompt (.zip format):') with gr.Row(): - upload_softprompt = gr.File(type='binary', file_types=[".zip"]) + shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) - model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True) - preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping]) - softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True) - upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu]) - return preset_menu, do_sample, temperature, top_p, typical_p, 
repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping + shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) + shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']]) + shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True) + shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']]) available_models = get_available_models() available_presets = get_available_presets() @@ -159,25 +158,24 @@ if shared.args.model is not None: shared.model_name = shared.args.model else: if len(available_models) == 0: - print("No models are available! Please download at least one.") + print('No models are available! Please download at least one.') sys.exit(0) elif len(available_models) == 1: i = 0 else: - print("The following models are available:\n") + print('The following models are available:\n') for i, model in enumerate(available_models): - print(f"{i+1}. {model}") - print(f"\nWhich one do you want to load? 1-{len(available_models)}\n") + print(f'{i+1}. {model}') + print(f'\nWhich one do you want to load? 1-{len(available_models)}\n') i = int(input())-1 print() shared.model_name = available_models[i] shared.model, shared.tokenizer = load_model(shared.model_name) # UI settings -buttons = {} gen_events = [] suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' -description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" +description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n' if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): default_text = shared.settings['prompt_gpt4chan'] elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None: @@ -186,176 +184,169 @@ else: default_text = shared.settings['prompt'] if shared.args.chat or shared.args.cai_chat: - with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface: - interface.load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) + with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as shared.gradio['interface']: if shared.args.cai_chat: - display = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) + shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) else: - display = gr.Chatbot(value=shared.history['visible']) - textbox = gr.Textbox(label='Input') + shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']) + shared.gradio['textbox'] = gr.Textbox(label='Input') with gr.Row(): - buttons["Stop"] = gr.Button("Stop") - buttons["Generate"] = gr.Button("Generate") - buttons["Regenerate"] = 
gr.Button("Regenerate") + shared.gradio['Stop'] = gr.Button('Stop') + shared.gradio['Generate'] = gr.Button('Generate') + shared.gradio['Regenerate'] = gr.Button('Regenerate') with gr.Row(): - buttons["Impersonate"] = gr.Button("Impersonate") - buttons["Remove last"] = gr.Button("Remove last") - buttons["Clear history"] = gr.Button("Clear history") + shared.gradio['Impersonate'] = gr.Button('Impersonate') + shared.gradio['Remove last'] = gr.Button('Remove last') + shared.gradio['Clear history'] = gr.Button('Clear history') with gr.Row(): - buttons["Send last reply to input"] = gr.Button("Send last reply to input") - buttons["Replace last reply"] = gr.Button("Replace last reply") - if shared.args.picture: + shared.gradio['Send last reply to input'] = gr.Button('Send last reply to input') + shared.gradio['Replace last reply'] = gr.Button('Replace last reply') + with gr.Tab('Chat settings'): + shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') + shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') + shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=2, label='Context') with gr.Row(): - picture_select = gr.Image(label="Send a picture", type='pil') - - with gr.Tab("Chat settings"): - name1 = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') - name2 = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') - context = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=2, label='Context') - with gr.Row(): - character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character') - ui.create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button") + shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character') + ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') with gr.Row(): - check = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') + shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') with gr.Row(): with gr.Tab('Chat history'): with gr.Row(): with gr.Column(): gr.Markdown('Upload') - upload_chat_history = gr.File(type='binary', file_types=[".json", ".txt"]) + shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt']) with gr.Column(): gr.Markdown('Download') - download = gr.File() - buttons["Download"] = gr.Button(value="Click me") + shared.gradio['download'] = gr.File() + shared.gradio['download_button'] = gr.Button(value='Click me') with gr.Tab('Upload character'): with gr.Row(): with gr.Column(): gr.Markdown('1. Select the JSON file') - upload_char = gr.File(type='binary', file_types=[".json"]) + shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json']) with gr.Column(): gr.Markdown('2. 
Select your character\'s profile picture (optional)') - upload_img = gr.File(type='binary', file_types=["image"]) - buttons["Upload character"] = gr.Button(value="Submit") + shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image']) + shared.gradio['Upload character'] = gr.Button(value='Submit') with gr.Tab('Upload your profile picture'): - upload_img_me = gr.File(type='binary', file_types=["image"]) + shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image']) with gr.Tab('Upload TavernAI Character Card'): - upload_img_tavern = gr.File(type='binary', file_types=["image"]) + shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) - with gr.Tab("Generation settings"): + with gr.Tab('Generation settings'): with gr.Row(): with gr.Column(): - max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) with gr.Column(): - chat_prompt_size_slider = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) - - preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() + shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) + shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts') + create_settings_menus() + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] if shared.args.extensions is not None: - with gr.Tab("Extensions"): + with gr.Tab('Extensions'): extensions_module.create_extensions_block() - input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider] - if shared.args.picture: - input_params.append(picture_select) - function_call = "chat.cai_chatbot_wrapper" if shared.args.cai_chat else "chat.chatbot_wrapper" + function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' - gen_events.append(buttons["Generate"].click(eval(function_call), input_params, display, show_progress=shared.args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(eval(function_call), input_params, display, show_progress=shared.args.no_stream)) - if shared.args.picture: - 
picture_select.upload(eval(function_call), input_params, display, show_progress=shared.args.no_stream) - gen_events.append(buttons["Regenerate"].click(chat.regenerate_wrapper, input_params, display, show_progress=shared.args.no_stream)) - gen_events.append(buttons["Impersonate"].click(chat.impersonate_wrapper, input_params, textbox, show_progress=shared.args.no_stream)) - buttons["Stop"].click(chat.stop_everything_event, [], [], cancels=gen_events) + gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) + shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events) - buttons["Send last reply to input"].click(chat.send_last_reply_to_input, [], textbox, show_progress=shared.args.no_stream) - buttons["Replace last reply"].click(chat.replace_last_reply, [textbox, name1, name2], display, show_progress=shared.args.no_stream) - buttons["Clear history"].click(chat.clear_chat_log, [name1, name2], display) - buttons["Remove last"].click(chat.remove_last_message, [name1, name2], [display, textbox], show_progress=False) - buttons["Download"].click(chat.save_history, inputs=[], outputs=[download]) - buttons["Upload character"].click(chat.upload_character, [upload_char, upload_img], [character_menu]) + shared.gradio['Send last reply to input'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) + shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream) + shared.gradio['Clear history'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display']) + shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) + shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']]) + shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) # Clearing stuff and saving the history - for i in ["Generate", "Regenerate", "Replace last reply"]: - buttons[i].click(lambda x: "", textbox, textbox, show_progress=False) - buttons[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - buttons["Clear history"].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - textbox.submit(lambda x: "", textbox, textbox, show_progress=False) - textbox.submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + for i in ['Generate', 'Regenerate', 'Replace last reply']: + shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) + shared.gradio[i].click(lambda : 
chat.save_history(timestamp=False), [], [], show_progress=False) + shared.gradio['Clear history'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) + shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - character_menu.change(chat.load_character, [character_menu, name1, name2], [name2, context, display]) - upload_chat_history.upload(chat.load_history, [upload_chat_history, name1, name2], []) - upload_img_tavern.upload(chat.upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu]) - upload_img_me.upload(chat.upload_your_profile_picture, [upload_img_me], []) - if shared.args.picture: - picture_select.upload(lambda : None, [], [picture_select], show_progress=False) + shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']]) + shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) + shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) + shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], []) reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] - reload_inputs = [name1, name2] if shared.args.cai_chat else [] - upload_chat_history.upload(reload_func, reload_inputs, [display]) - upload_img_me.upload(reload_func, reload_inputs, [display]) - interface.load(reload_func, reload_inputs, [display], show_progress=True) + reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] + shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) + shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) + + shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) + shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) elif shared.args.notebook: - with gr.Blocks(css=ui.css, analytics_enabled=False) as interface: + with gr.Blocks(css=ui.css, analytics_enabled=False) as shared.gradio['interface']: gr.Markdown(description) with gr.Tab('Raw'): - textbox = gr.Textbox(value=default_text, lines=23) + shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23) with gr.Tab('Markdown'): - markdown = gr.Markdown() + shared.gradio['markdown'] = gr.Markdown() with gr.Tab('HTML'): - html = gr.HTML() + shared.gradio['html'] = gr.HTML() - buttons["Generate"] = gr.Button("Generate") - buttons["Stop"] = gr.Button("Stop") - - max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - - preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() + shared.gradio['Generate'] = 
gr.Button('Generate') + shared.gradio['Stop'] = gr.Button('Stop') + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + create_settings_menus() if shared.args.extensions is not None: extensions_module.create_extensions_block() - gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream)) - buttons["Stop"].click(None, None, None, cancels=gen_events) + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) + shared.gradio['Stop'].click(None, None, None, cancels=gen_events) else: - with gr.Blocks(css=ui.css, analytics_enabled=False) as interface: + with gr.Blocks(css=ui.css, analytics_enabled=False) as shared.gradio['interface']: gr.Markdown(description) with gr.Row(): with gr.Column(): - textbox = gr.Textbox(value=default_text, lines=15, label='Input') - max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - buttons["Generate"] = gr.Button("Generate") + shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input') + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['Generate'] = gr.Button('Generate') with gr.Row(): with gr.Column(): - buttons["Continue"] = gr.Button("Continue") + shared.gradio['Continue'] = gr.Button('Continue') with gr.Column(): - buttons["Stop"] = gr.Button("Stop") + shared.gradio['Stop'] = gr.Button('Stop') - preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() + create_settings_menus() if shared.args.extensions is not None: extensions_module.create_extensions_block() with gr.Column(): with gr.Tab('Raw'): - output_textbox = gr.Textbox(lines=15, label='Output') + shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output') with gr.Tab('Markdown'): - markdown = gr.Markdown() + shared.gradio['markdown'] = gr.Markdown() with gr.Tab('HTML'): - html = 
gr.HTML() + shared.gradio['html'] = gr.HTML() - gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream)) - gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream)) - buttons["Stop"].click(None, None, None, cancels=gen_events) + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) + shared.gradio['Stop'].click(None, None, None, cancels=gen_events) -interface.queue() +shared.gradio['interface'].queue() if shared.args.listen: - interface.launch(prevent_thread_lock=True, share=shared.args.share, server_name="0.0.0.0", server_port=shared.args.listen_port) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port) else: - interface.launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port) # I think that I will need this later while True: diff --git a/settings-template.json b/settings-template.json index dae76960..13165641 100644 --- a/settings-template.json +++ b/settings-template.json @@ -12,6 +12,9 @@ "chat_prompt_size": 2048, "chat_prompt_size_min": 0, "chat_prompt_size_max": 2048, + "chat_generation_attempts": 1, + "chat_generation_attempts_min": 1, + "chat_generation_attempts_max": 5, "preset_pygmalion": "Pygmalion", "name1_pygmalion": "You", "name2_pygmalion": "Kawaii",
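One caveat with the new `shared.input_params` registry: gradio passes the list positionally, so the key order must mirror the wrapper signatures exactly ('textbox' feeds `text`, 'chat_prompt_size_slider' feeds `chat_prompt_size`, and the new 'chat_generation_attempts' slider lands in the trailing keyword argument). A hypothetical drift check, not part of this commit, that could guard the chat path:

```python
# Hypothetical sanity check (not part of this commit): verify that the
# chat-mode input_params key list still lines up, one to one, with
# chatbot_wrapper's positional signature.
import inspect

import modules.chat as chat

chat_keys = ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p',
             'typical_p', 'repetition_penalty', 'top_k', 'min_length',
             'no_repeat_ngram_size', 'num_beams', 'penalty_alpha',
             'length_penalty', 'early_stopping', 'name1', 'name2', 'context',
             'check', 'chat_prompt_size_slider', 'chat_generation_attempts']
wrapper_params = list(inspect.signature(chat.chatbot_wrapper).parameters)

# Same length, same order; only the widget names are allowed to differ
# from the parameter names (textbox -> text, ..._slider -> chat_prompt_size)
assert len(chat_keys) == len(wrapper_params), (chat_keys, wrapper_params)
```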