diff --git a/api-example.py b/api-example.py index a5967a9d..f2f6c51e 100644 --- a/api-example.py +++ b/api-example.py @@ -10,7 +10,6 @@ Optionally, you can also add the --share flag to generate a public gradio URL, allowing you to use the API remotely. ''' - import requests # Server address diff --git a/convert-to-flexgen.py b/convert-to-flexgen.py index 18afa9bd..917f023c 100644 --- a/convert-to-flexgen.py +++ b/convert-to-flexgen.py @@ -6,15 +6,13 @@ Converts a transformers model to a format compatible with flexgen. import argparse import os -import numpy as np from pathlib import Path -from sys import argv +import numpy as np import torch from tqdm import tqdm -from transformers import AutoModelForCausalLM -from transformers import AutoTokenizer - +from transformers import AutoModelForCausalLM, AutoTokenizer + parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") args = parser.parse_args() @@ -33,7 +31,6 @@ def disable_torch_init(): torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) - def restore_torch_init(): """Rollback the change made by disable_torch_init.""" import torch diff --git a/convert-to-safetensors.py b/convert-to-safetensors.py index 60770843..63baaa97 100644 --- a/convert-to-safetensors.py +++ b/convert-to-safetensors.py @@ -13,12 +13,10 @@ https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303 import argparse from pathlib import Path -from sys import argv import torch -from transformers import AutoModelForCausalLM -from transformers import AutoTokenizer - +from transformers import AutoModelForCausalLM, AutoTokenizer + parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).') diff --git a/modules/bot_picture.py b/modules/bot_picture.py index f407c379..dd4d73eb 100644 --- a/modules/bot_picture.py +++ b/modules/bot_picture.py @@ -1,8 +1,5 @@ -import requests import torch -from PIL import Image -from transformers import BlipForConditionalGeneration -from transformers import BlipProcessor +from transformers import BlipForConditionalGeneration, BlipProcessor processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") diff --git a/modules/chat.py b/modules/chat.py new file mode 100644 index 00000000..356119a2 --- /dev/null +++ b/modules/chat.py @@ -0,0 +1,366 @@ +import base64 +import copy +import io +import json +import re +from datetime import datetime +from io import BytesIO +from pathlib import Path + +from PIL import Image + +import modules.shared as shared +from modules.extensions import apply_extensions +from modules.html_generator import generate_chat_html +from modules.text_generation import encode, generate_reply, get_max_prompt_length + +if shared.args.picture and (shared.args.cai_chat or shared.args.chat): + import modules.bot_picture as bot_picture + +# This gets the new line characters right. +def clean_chat_message(text): + text = text.replace('\n', '\n\n') + text = re.sub(r"\n{3,}", "\n\n", text) + text = text.strip() + return text + +def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False): + text = clean_chat_message(text) + rows = [f"{context.strip()}\n"] + i = len(shared.history['internal'])-1 + count = 0 + + if shared.soft_prompt: + chat_prompt_size -= shared.soft_prompt_tensor.shape[1] + max_length = min(get_max_prompt_length(tokens), chat_prompt_size) + + while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length: + rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n") + count += 1 + if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'): + rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n") + count += 1 + i -= 1 + + if not impersonate: + rows.append(f"{name1}: {text}\n") + rows.append(apply_extensions(f"{name2}:", "bot_prefix")) + limit = 3 + else: + rows.append(f"{name1}:") + limit = 2 + + while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length: + rows.pop(1) + rows.pop(1) + + question = ''.join(rows) + return question + +def extract_message_from_reply(question, reply, current, other, check, extensions=False): + next_character_found = False + substring_found = False + + previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", question)] + idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", reply)] + idx = idx[len(previous_idx)-1] + + if extensions: + reply = reply[idx + 1 + len(apply_extensions(f"{current}:", "bot_prefix")):] + else: + reply = reply[idx + 1 + len(f"{current}:"):] + + if check: + reply = reply.split('\n')[0].strip() + else: + idx = reply.find(f"\n{other}:") + if idx != -1: + reply = reply[:idx] + next_character_found = True + reply = clean_chat_message(reply) + + # Detect if something like "\nYo" is generated just before + # "\nYou:" is completed + tmp = f"\n{other}:" + for j in range(1, len(tmp)): + if reply[-j:] == tmp[:j]: + substring_found = True + + return reply, next_character_found, substring_found + +def generate_chat_picture(picture, name1, name2): + text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*' + buffer = BytesIO() + picture.save(buffer, format="JPEG") + img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') + visible_text = f'' + return text, visible_text + +def stop_everything_event(): + shared.stop_everything = True + +def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): + shared.stop_everything = False + + if 'pygmalion' in shared.model_name.lower(): + name1 = "You" + + if shared.args.picture and picture is not None: + text, visible_text = generate_chat_picture(picture, name1, name2) + else: + visible_text = text + if shared.args.chat: + visible_text = visible_text.replace('\n', '
') + + text = apply_extensions(text, "input") + question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size) + eos_token = '\n' if check else None + first = True + for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): + reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True) + visible_reply = apply_extensions(reply, "output") + if shared.args.chat: + visible_reply = visible_reply.replace('\n', '
') + + # We need this global variable to handle the Stop event, + # otherwise gradio gets confused + if shared.stop_everything: + return shared.history['visible'] + + if first: + first = False + shared.history['internal'].append(['', '']) + shared.history['visible'].append(['', '']) + + shared.history['internal'][-1] = [text, reply] + shared.history['visible'][-1] = [visible_text, visible_reply] + if not substring_found: + yield shared.history['visible'] + if next_character_found: + break + yield shared.history['visible'] + +def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): + if 'pygmalion' in shared.model_name.lower(): + name1 = "You" + + question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True) + eos_token = '\n' if check else None + for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): + reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False) + if not substring_found: + yield reply + if next_character_found: + break + yield reply + +def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): + for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): + yield generate_chat_html(_history, name1, name2, shared.character) + +def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): + if shared.character is not None and len(shared.history['visible']) == 1: + if shared.args.cai_chat: + yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) + else: + yield shared.history['visible'] + else: + last_visible = shared.history['visible'].pop() + last_internal = shared.history['internal'].pop() + + for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): + if shared.args.cai_chat: + shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] + yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) + else: + shared.history['visible'][-1] = (last_visible[0], _history[-1][1]) + yield shared.history['visible'] + +def remove_last_message(name1, name2): + if not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': + last = shared.history['visible'].pop() + shared.history['internal'].pop() + else: + last = ['', ''] + if shared.args.cai_chat: + return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0] + else: + return shared.history['visible'], last[0] + +def send_last_reply_to_input(): + if len(shared.history['internal']) > 0: + return shared.history['internal'][-1][1] + else: + return '' + +def replace_last_reply(text, name1, name2): + if len(shared.history['visible']) > 0: + if shared.args.cai_chat: + shared.history['visible'][-1][1] = text + else: + shared.history['visible'][-1] = (shared.history['visible'][-1][0], text) + shared.history['internal'][-1][1] = apply_extensions(text, "input") + + if shared.args.cai_chat: + return generate_chat_html(shared.history['visible'], name1, name2, shared.character) + else: + return shared.history['visible'] + +def clear_html(): + return generate_chat_html([], "", "", shared.character) + +def clear_chat_log(name1, name2): + if shared.character != 'None': + for i in range(len(shared.history['internal'])): + if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]: + shared.history['visible'] = [['', apply_extensions(shared.history['internal'][i][1], "output")]] + shared.history['internal'] = shared.history['internal'][:i+1] + break + else: + shared.history['internal'] = [] + shared.history['visible'] = [] + if shared.args.cai_chat: + return generate_chat_html(shared.history['visible'], name1, name2, shared.character) + else: + return shared.history['visible'] + +def redraw_html(name1, name2): + return generate_chat_html(shared.history['visible'], name1, name2, shared.character) + +def tokenize_dialogue(dialogue, name1, name2): + _history = [] + + dialogue = re.sub('', '', dialogue) + dialogue = re.sub('', '', dialogue) + dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) + dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\g<1>{name2}:', dialogue) + idx = [m.start() for m in re.finditer(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):", dialogue)] + if len(idx) == 0: + return _history + + messages = [] + for i in range(len(idx)-1): + messages.append(dialogue[idx[i]:idx[i+1]].strip()) + messages.append(dialogue[idx[-1]:].strip()) + + entry = ['', ''] + for i in messages: + if i.startswith(f'{name1}:'): + entry[0] = i[len(f'{name1}:'):].strip() + elif i.startswith(f'{name2}:'): + entry[1] = i[len(f'{name2}:'):].strip() + if not (len(entry[0]) == 0 and len(entry[1]) == 0): + _history.append(entry) + entry = ['', ''] + + print(f"\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='') + for row in _history: + for column in row: + print("\n") + for line in column.strip().split('\n'): + print("| "+line+"\n") + print("|\n") + print("------------------------------") + + return _history + +def save_history(timestamp=True): + if timestamp: + fname = f"{shared.character or ''}{'_' if shared.character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" + else: + fname = f"{shared.character or ''}{'_' if shared.character else ''}persistent.json" + if not Path('logs').exists(): + Path('logs').mkdir() + with open(Path(f'logs/{fname}'), 'w') as f: + f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) + return Path(f'logs/{fname}') + +def load_history(file, name1, name2): + file = file.decode('utf-8') + try: + j = json.loads(file) + if 'data' in j: + shared.history['internal'] = j['data'] + if 'data_visible' in j: + shared.history['visible'] = j['data_visible'] + else: + shared.history['visible'] = copy.deepcopy(shared.history['internal']) + # Compatibility with Pygmalion AI's official web UI + elif 'chat' in j: + shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']] + if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'): + shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)] + shared.history['visible'] = copy.deepcopy(shared.history['internal']) + shared.history['visible'][0][0] = '' + else: + shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)] + shared.history['visible'] = copy.deepcopy(shared.history['internal']) + except: + shared.history['internal'] = tokenize_dialogue(file, name1, name2) + shared.history['visible'] = copy.deepcopy(shared.history['internal']) + +def load_character(_character, name1, name2): + context = "" + shared.history['internal'] = [] + shared.history['visible'] = [] + if _character != 'None': + shared.character = _character + data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read()) + name2 = data['char_name'] + if 'char_persona' in data and data['char_persona'] != '': + context += f"{data['char_name']}'s Persona: {data['char_persona']}\n" + if 'world_scenario' in data and data['world_scenario'] != '': + context += f"Scenario: {data['world_scenario']}\n" + context = f"{context.strip()}\n\n" + if 'example_dialogue' in data and data['example_dialogue'] != '': + shared.history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2) + if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0: + shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]] + shared.history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]] + else: + shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]] + shared.history['visible'] += [['', "Hello there!"]] + else: + shared.character = None + context = shared.settings['context_pygmalion'] + name2 = shared.settings['name2_pygmalion'] + + if Path(f'logs/{shared.character}_persistent.json').exists(): + load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) + + if shared.args.cai_chat: + return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character) + else: + return name2, context, shared.history['visible'] + +def upload_character(json_file, img, tavern=False): + json_file = json_file if type(json_file) == str else json_file.decode('utf-8') + data = json.loads(json_file) + outfile_name = data["char_name"] + i = 1 + while Path(f'characters/{outfile_name}.json').exists(): + outfile_name = f'{data["char_name"]}_{i:03d}' + i += 1 + if tavern: + outfile_name = f'TavernAI-{outfile_name}' + with open(Path(f'characters/{outfile_name}.json'), 'w') as f: + f.write(json_file) + if img is not None: + img = Image.open(io.BytesIO(img)) + img.save(Path(f'characters/{outfile_name}.png')) + print(f'New character saved to "characters/{outfile_name}.json".') + return outfile_name + +def upload_tavern_character(img, name1, name2): + _img = Image.open(io.BytesIO(img)) + _img.getexif() + decoded_string = base64.b64decode(_img.info['chara']) + _json = json.loads(decoded_string) + _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} + _json['example_dialogue'] = _json['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', _json['char_name']) + return upload_character(json.dumps(_json), img, tavern=True) + +def upload_your_profile_picture(img): + img = Image.open(io.BytesIO(img)) + img.save(Path(f'img_me.png')) + print(f'Profile picture saved to "img_me.png"') diff --git a/modules/extensions.py b/modules/extensions.py new file mode 100644 index 00000000..1ab07761 --- /dev/null +++ b/modules/extensions.py @@ -0,0 +1,64 @@ +import extensions +import modules.shared as shared +import gradio as gr + +extension_state = {} +available_extensions = [] + +def load_extensions(): + global extension_state + for i,ext in enumerate(shared.args.extensions.split(',')): + if ext in available_extensions: + print(f'Loading the extension "{ext}"... ', end='') + ext_string = f"extensions.{ext}.script" + exec(f"import {ext_string}") + extension_state[ext] = [True, i] + print(f'Ok.') + +def apply_extensions(text, typ): + for ext in sorted(extension_state, key=lambda x : extension_state[x][1]): + if extension_state[ext][0] == True: + ext_string = f"extensions.{ext}.script" + if typ == "input" and hasattr(eval(ext_string), "input_modifier"): + text = eval(f"{ext_string}.input_modifier(text)") + elif typ == "output" and hasattr(eval(ext_string), "output_modifier"): + text = eval(f"{ext_string}.output_modifier(text)") + elif typ == "bot_prefix" and hasattr(eval(ext_string), "bot_prefix_modifier"): + text = eval(f"{ext_string}.bot_prefix_modifier(text)") + return text + +def update_extensions_parameters(*kwargs): + i = 0 + for ext in sorted(extension_state, key=lambda x : extension_state[x][1]): + if extension_state[ext][0] == True: + params = eval(f"extensions.{ext}.script.params") + for param in params: + if len(kwargs) >= i+1: + params[param] = eval(f"kwargs[{i}]") + i += 1 + +def get_params(name): + return eval(f"extensions.{name}.script.params") + +def create_extensions_block(): + extensions_ui_elements = [] + default_values = [] + if not (shared.args.chat or shared.args.cai_chat): + gr.Markdown('## Extensions parameters') + for ext in sorted(extension_state, key=lambda x : extension_state[x][1]): + if extension_state[ext][0] == True: + params = get_params(ext) + for param in params: + _id = f"{ext}-{param}" + default_value = shared.settings[_id] if _id in shared.settings else params[param] + default_values.append(default_value) + if type(params[param]) == str: + extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}")) + elif type(params[param]) in [int, float]: + extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}")) + elif type(params[param]) == bool: + extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}")) + + update_extensions_parameters(*default_values) + btn_extensions = gr.Button("Apply") + btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], []) diff --git a/modules/html_generator.py b/modules/html_generator.py index f0e26392..6e1fb8ac 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -5,7 +5,6 @@ This is a library for formatting GPT-4chan and chat outputs as nice HTML. ''' import base64 -import copy import os import re from io import BytesIO diff --git a/modules/models.py b/modules/models.py new file mode 100644 index 00000000..efa3eb25 --- /dev/null +++ b/modules/models.py @@ -0,0 +1,150 @@ +import json +import os +import time +import zipfile +from pathlib import Path + +import numpy as np +import torch +import transformers +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modules.shared as shared + +transformers.logging.set_verbosity_error() + +local_rank = None + +if shared.args.flexgen: + from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy, + TorchDevice, TorchDisk, TorchMixedDevice, + get_opt_config) + +if shared.args.deepspeed: + import deepspeed + from transformers.deepspeed import (HfDeepSpeedConfig, + is_deepspeed_zero3_enabled) + + from modules.deepspeed_parameters import generate_ds_config + + # Distributed setup + local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + torch.cuda.set_device(local_rank) + deepspeed.init_distributed() + ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir) + dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration + +def load_model(model_name): + print(f"Loading {model_name}...") + t0 = time.time() + + # Default settings + if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen): + if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): + model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) + else: + model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16).cuda() + + # FlexGen + elif shared.args.flexgen: + gpu = TorchDevice("cuda:0") + cpu = TorchDevice("cpu") + disk = TorchDisk(shared.args.disk_cache_dir) + env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk])) + + # Offloading policy + policy = Policy(1, 1, + shared.args.percent[0], shared.args.percent[1], + shared.args.percent[2], shared.args.percent[3], + shared.args.percent[4], shared.args.percent[5], + overlap=True, sep_layer=True, pin_weight=True, + cpu_cache_compute=False, attn_sparsity=1.0, + compress_weight=shared.args.compress_weight, + comp_weight_config=CompressionConfig( + num_bits=4, group_size=64, + group_dim=0, symmetric=False), + compress_cache=False, + comp_cache_config=CompressionConfig( + num_bits=4, group_size=64, + group_dim=2, symmetric=False)) + + opt_config = get_opt_config(f"facebook/{shared.model_name}") + model = OptLM(opt_config, env, "models", policy) + model.init_all_weights() + + # DeepSpeed ZeRO-3 + elif shared.args.deepspeed: + model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) + model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] + model.module.eval() # Inference + print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") + + # Custom + else: + command = "AutoModelForCausalLM.from_pretrained" + params = ["low_cpu_mem_usage=True"] + if not shared.args.cpu and not torch.cuda.is_available(): + print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n") + shared.args.cpu = True + + if shared.args.cpu: + params.append("low_cpu_mem_usage=True") + params.append("torch_dtype=torch.float32") + else: + params.append("device_map='auto'") + params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16") + + if shared.args.gpu_memory: + params.append(f"max_memory={{0: '{shared.args.gpu_memory or '99'}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}") + elif not shared.args.load_in_8bit: + total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024)) + suggestion = round((total_mem-1000)/1000)*1000 + if total_mem-suggestion < 800: + suggestion -= 1000 + suggestion = int(round(suggestion/1000)) + print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") + params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}") + if shared.args.disk: + params.append(f"offload_folder='{shared.args.disk_cache_dir}'") + + command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})" + model = eval(command) + + # Loading the tokenizer + if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"models/gpt-j-6B/").exists(): + tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/")) + else: + tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/")) + tokenizer.truncation_side = 'left' + + print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") + return model, tokenizer + +def load_soft_prompt(name): + if name == 'None': + shared.soft_prompt = False + shared.soft_prompt_tensor = None + else: + with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf: + zf.extract('tensor.npy') + zf.extract('meta.json') + j = json.loads(open('meta.json', 'r').read()) + print(f"\nLoading the softprompt \"{name}\".") + for field in j: + if field != 'name': + if type(j[field]) is list: + print(f"{field}: {', '.join(j[field])}") + else: + print(f"{field}: {j[field]}") + print() + tensor = np.load('tensor.npy') + Path('tensor.npy').unlink() + Path('meta.json').unlink() + tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype) + tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1])) + + shared.soft_prompt = True + shared.soft_prompt_tensor = tensor + + return name diff --git a/modules/shared.py b/modules/shared.py new file mode 100644 index 00000000..68c2fb78 --- /dev/null +++ b/modules/shared.py @@ -0,0 +1,62 @@ +import argparse + +model = None +tokenizer = None +model_name = "" +soft_prompt_tensor = None +soft_prompt = False + +# Chat variables +history = {'internal': [], 'visible': []} +character = 'None' +stop_everything = False + +settings = { + 'max_new_tokens': 200, + 'max_new_tokens_min': 1, + 'max_new_tokens_max': 2000, + 'preset': 'NovelAI-Sphinx Moth', + 'name1': 'Person 1', + 'name2': 'Person 2', + 'context': 'This is a conversation between two people.', + 'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', + 'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n', + 'stop_at_newline': True, + 'chat_prompt_size': 2048, + 'chat_prompt_size_min': 0, + 'chat_prompt_size_max': 2048, + 'preset_pygmalion': 'Pygmalion', + 'name1_pygmalion': 'You', + 'name2_pygmalion': 'Kawaii', + 'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n", + 'stop_at_newline_pygmalion': False, +} + +parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) +parser.add_argument('--model', type=str, help='Name of the model to load by default.') +parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') +parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.') +parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') +parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.') +parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') +parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') +parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') +parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') +parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.') +parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".') +parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.') +parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.') +parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.') +parser.add_argument('--percent', nargs="+", type=int, default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).') +parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.") +parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') +parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.') +parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.') +parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.') +parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') +parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".') +parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') +parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') +parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') +parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') +args = parser.parse_args() diff --git a/modules/stopping_criteria.py b/modules/stopping_criteria.py index 3baadf6c..44a631b3 100644 --- a/modules/stopping_criteria.py +++ b/modules/stopping_criteria.py @@ -8,6 +8,7 @@ https://github.com/PygmalionAI/gradio-ui/ import torch import transformers + class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): def __init__(self, sentinel_token_ids: torch.LongTensor, diff --git a/modules/text_generation.py b/modules/text_generation.py new file mode 100644 index 00000000..d0204102 --- /dev/null +++ b/modules/text_generation.py @@ -0,0 +1,178 @@ +import re +import time + +import numpy as np +import torch +import transformers +from tqdm import tqdm + +import modules.shared as shared +from modules.extensions import apply_extensions +from modules.html_generator import generate_4chan_html, generate_basic_html +from modules.models import local_rank +from modules.stopping_criteria import _SentinelTokenStoppingCriteria + + +def get_max_prompt_length(tokens): + max_length = 2048-tokens + if shared.soft_prompt: + max_length -= shared.soft_prompt_tensor.shape[1] + return max_length + +def encode(prompt, tokens_to_generate=0, add_special_tokens=True): + input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) + if shared.args.cpu or shared.args.flexgen: + return input_ids + elif shared.args.deepspeed: + return input_ids.to(device=local_rank) + else: + return input_ids.cuda() + +def decode(output_ids): + reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True) + reply = reply.replace(r'<|endoftext|>', '') + return reply + +def generate_softprompt_input_tensors(input_ids): + inputs_embeds = shared.model.transformer.wte(input_ids) + inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1) + filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device) + filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens + return inputs_embeds, filler_input_ids + +# Removes empty replies from gpt4chan outputs +def fix_gpt4chan(s): + for i in range(10): + s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) + s = re.sub("--- [0-9]*\n *\n---", "---", s) + s = re.sub("--- [0-9]*\n\n\n---", "---", s) + return s + +# Fix the LaTeX equations in galactica +def fix_galactica(s): + s = s.replace(r'\[', r'$') + s = s.replace(r'\]', r'$') + s = s.replace(r'\(', r'$') + s = s.replace(r'\)', r'$') + s = s.replace(r'$$', r'$') + s = re.sub(r'\n', r'\n\n', s) + s = re.sub(r"\n{3,}", "\n\n", s) + return s + +def formatted_outputs(reply, model_name): + if not (shared.args.chat or shared.args.cai_chat): + if shared.model_name.lower().startswith('galactica'): + reply = fix_galactica(reply) + return reply, reply, generate_basic_html(reply) + elif shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): + reply = fix_gpt4chan(reply) + return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) + else: + return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply) + else: + return reply + +def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): + original_question = question + if not (shared.args.chat or shared.args.cai_chat): + question = apply_extensions(question, "input") + if shared.args.verbose: + print(f"\n\n{question}\n--------------------\n") + + input_ids = encode(question, tokens) + cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" + if not shared.args.flexgen: + n = shared.tokenizer.eos_token_id if eos_token is None else shared.tokenizer.encode(eos_token, return_tensors='pt')[0][-1] + else: + n = shared.tokenizer(eos_token).input_ids[0] if eos_token else None + + if stopping_string is not None: + # The stopping_criteria code below was copied from + # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py + t = encode(stopping_string, 0, add_special_tokens=False) + stopping_criteria_list = transformers.StoppingCriteriaList([ + _SentinelTokenStoppingCriteria( + sentinel_token_ids=t, + starting_idx=len(input_ids[0]) + ) + ]) + else: + stopping_criteria_list = None + + if not shared.args.flexgen: + generate_params = [ + f"eos_token_id={n}", + f"stopping_criteria=stopping_criteria_list", + f"do_sample={do_sample}", + f"temperature={temperature}", + f"top_p={top_p}", + f"typical_p={typical_p}", + f"repetition_penalty={repetition_penalty}", + f"top_k={top_k}", + f"min_length={min_length if shared.args.no_stream else 0}", + f"no_repeat_ngram_size={no_repeat_ngram_size}", + f"num_beams={num_beams}", + f"penalty_alpha={penalty_alpha}", + f"length_penalty={length_penalty}", + f"early_stopping={early_stopping}", + ] + else: + generate_params = [ + f"do_sample={do_sample}", + f"temperature={temperature}", + f"stop={n}", + ] + + if shared.args.deepspeed: + generate_params.append("synced_gpus=True") + if shared.args.no_stream: + generate_params.append(f"max_new_tokens=tokens") + else: + generate_params.append(f"max_new_tokens=8") + + if shared.soft_prompt: + inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) + generate_params.insert(0, "inputs_embeds=inputs_embeds") + generate_params.insert(0, "filler_input_ids") + else: + generate_params.insert(0, "input_ids") + + # Generate the entire reply at once + if shared.args.no_stream: + t0 = time.time() + with torch.no_grad(): + output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] + if shared.soft_prompt: + output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) + + reply = decode(output) + if not (shared.args.chat or shared.args.cai_chat): + reply = original_question + apply_extensions(reply[len(question):], "output") + yield formatted_outputs(reply, shared.model_name) + + t1 = time.time() + print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)") + + # Generate the reply 8 tokens at a time + else: + yield formatted_outputs(original_question, shared.model_name) + for i in tqdm(range(tokens//8+1)): + with torch.no_grad(): + output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] + if shared.soft_prompt: + output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) + + reply = decode(output) + if not (shared.args.chat or shared.args.cai_chat): + reply = original_question + apply_extensions(reply[len(question):], "output") + yield formatted_outputs(reply, shared.model_name) + + if not shared.args.flexgen: + input_ids = torch.reshape(output, (1, output.shape[0])) + else: + input_ids = np.reshape(output, (1, output.shape[0])) + if shared.soft_prompt: + inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) + + if output[-1] == n: + break diff --git a/server.py b/server.py index d1066917..ee5c9372 100644 --- a/server.py +++ b/server.py @@ -1,251 +1,55 @@ -import argparse -import base64 -import copy import gc -import glob import io import json -import os import re import sys import time -import warnings import zipfile -from datetime import datetime from pathlib import Path import gradio as gr -import numpy as np import torch -import transformers -from PIL import Image -from tqdm import tqdm -from transformers import AutoConfig -from transformers import AutoModelForCausalLM -from transformers import AutoTokenizer -from io import BytesIO -from modules.html_generator import * -from modules.stopping_criteria import _SentinelTokenStoppingCriteria -from modules.ui import * +import modules.chat as chat +import modules.extensions as extensions_module +import modules.shared as shared +import modules.ui as ui +from modules.html_generator import generate_chat_html +from modules.models import load_model, load_soft_prompt +from modules.text_generation import generate_reply -transformers.logging.set_verbosity_error() - -parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) -parser.add_argument('--model', type=str, help='Name of the model to load by default.') -parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') -parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.') -parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') -parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.') -parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') -parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') -parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') -parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') -parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.') -parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".') -parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.') -parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.') -parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.') -parser.add_argument('--percent', nargs="+", type=int, default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).') -parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.") -parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') -parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.') -parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.') -parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.') -parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') -parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".') -parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') -parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') -parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') -parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') -args = parser.parse_args() - -if (args.chat or args.cai_chat) and not args.no_stream: +if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n") -settings = { - 'max_new_tokens': 200, - 'max_new_tokens_min': 1, - 'max_new_tokens_max': 2000, - 'preset': 'NovelAI-Sphinx Moth', - 'name1': 'Person 1', - 'name2': 'Person 2', - 'context': 'This is a conversation between two people.', - 'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', - 'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n', - 'stop_at_newline': True, - 'chat_prompt_size': 2048, - 'chat_prompt_size_min': 0, - 'chat_prompt_size_max': 2048, - 'preset_pygmalion': 'Pygmalion', - 'name1_pygmalion': 'You', - 'name2_pygmalion': 'Kawaii', - 'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n", - 'stop_at_newline_pygmalion': False, -} - -if args.settings is not None and Path(args.settings).exists(): - new_settings = json.loads(open(Path(args.settings), 'r').read()) +# Loading custom settings +if shared.args.settings is not None and Path(shared.args.settings).exists(): + new_settings = json.loads(open(Path(shared.args.settings), 'r').read()) for item in new_settings: - settings[item] = new_settings[item] + shared.settings[item] = new_settings[item] -if args.flexgen: - from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, Task, get_opt_config) +def get_available_models(): + return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower) -if args.deepspeed: - import deepspeed - from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled - from modules.deepspeed_parameters import generate_ds_config +def get_available_presets(): + return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) - # Distributed setup - local_rank = args.local_rank if args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - torch.cuda.set_device(local_rank) - deepspeed.init_distributed() - ds_config = generate_ds_config(args.bf16, 1 * world_size, args.nvme_offload_dir) - dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration +def get_available_characters(): + return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower) -if args.picture and (args.cai_chat or args.chat): - import modules.bot_picture as bot_picture +def get_available_extensions(): + return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower) -def load_model(model_name): - print(f"Loading {model_name}...") - t0 = time.time() - - # Default settings - if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed or args.flexgen): - if any(size in model_name.lower() for size in ('13b', '20b', '30b')): - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True) - else: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16).cuda() - - # FlexGen - elif args.flexgen: - gpu = TorchDevice("cuda:0") - cpu = TorchDevice("cpu") - disk = TorchDisk(args.disk_cache_dir) - env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk])) - - # Offloading policy - policy = Policy(1, 1, - args.percent[0], args.percent[1], - args.percent[2], args.percent[3], - args.percent[4], args.percent[5], - overlap=True, sep_layer=True, pin_weight=True, - cpu_cache_compute=False, attn_sparsity=1.0, - compress_weight=args.compress_weight, - comp_weight_config=CompressionConfig( - num_bits=4, group_size=64, - group_dim=0, symmetric=False), - compress_cache=False, - comp_cache_config=CompressionConfig( - num_bits=4, group_size=64, - group_dim=2, symmetric=False)) - - opt_config = get_opt_config(f"facebook/{model_name}") - model = OptLM(opt_config, env, "models", policy) - model.init_all_weights() - - # DeepSpeed ZeRO-3 - elif args.deepspeed: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16 if args.bf16 else torch.float16) - model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] - model.module.eval() # Inference - print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") - - # Custom - else: - command = "AutoModelForCausalLM.from_pretrained" - params = ["low_cpu_mem_usage=True"] - if not args.cpu and not torch.cuda.is_available(): - print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n") - args.cpu = True - - if args.cpu: - params.append("low_cpu_mem_usage=True") - params.append("torch_dtype=torch.float32") - else: - params.append("device_map='auto'") - params.append("load_in_8bit=True" if args.load_in_8bit else "torch_dtype=torch.bfloat16" if args.bf16 else "torch_dtype=torch.float16") - - if args.gpu_memory: - params.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}") - elif not args.load_in_8bit: - total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024)) - suggestion = round((total_mem-1000)/1000)*1000 - if total_mem-suggestion < 800: - suggestion -= 1000 - suggestion = int(round(suggestion/1000)) - print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") - params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}") - if args.disk: - params.append(f"offload_folder='{args.disk_cache_dir}'") - - command = f"{command}(Path(f'models/{model_name}'), {', '.join(set(params))})" - model = eval(command) - - # Loading the tokenizer - if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"models/gpt-j-6B/").exists(): - tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/")) - else: - tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/")) - tokenizer.truncation_side = 'left' - - print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") - return model, tokenizer - -def load_soft_prompt(name): - global soft_prompt, soft_prompt_tensor - - if name == 'None': - soft_prompt = False - soft_prompt_tensor = None - else: - with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf: - zf.extract('tensor.npy') - zf.extract('meta.json') - j = json.loads(open('meta.json', 'r').read()) - print(f"\nLoading the softprompt \"{name}\".") - for field in j: - if field != 'name': - if type(j[field]) is list: - print(f"{field}: {', '.join(j[field])}") - else: - print(f"{field}: {j[field]}") - print() - tensor = np.load('tensor.npy') - Path('tensor.npy').unlink() - Path('meta.json').unlink() - tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype) - tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1])) - - soft_prompt = True - soft_prompt_tensor = tensor - - return name - -def upload_soft_prompt(file): - with zipfile.ZipFile(io.BytesIO(file)) as zf: - zf.extract('meta.json') - j = json.loads(open('meta.json', 'r').read()) - name = j['name'] - Path('meta.json').unlink() - - with open(Path(f'softprompts/{name}.zip'), 'wb') as f: - f.write(file) - - return name +def get_available_softprompts(): + return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) def load_model_wrapper(selected_model): - global model_name, model, tokenizer - - if selected_model != model_name: - model_name = selected_model - model = tokenizer = None - if not args.cpu: + if selected_model != shared.model_name: + shared.model_name = selected_model + shared.model = shared.tokenizer = None + if not shared.args.cpu: gc.collect() torch.cuda.empty_cache() - model, tokenizer = load_model(model_name) + shared.model, shared.tokenizer = load_model(shared.model_name) return selected_model @@ -278,246 +82,30 @@ def load_preset_values(preset_menu, return_dict=False): else: return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] -# Removes empty replies from gpt4chan outputs -def fix_gpt4chan(s): - for i in range(10): - s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) - s = re.sub("--- [0-9]*\n *\n---", "---", s) - s = re.sub("--- [0-9]*\n\n\n---", "---", s) - return s +def upload_soft_prompt(file): + with zipfile.ZipFile(io.BytesIO(file)) as zf: + zf.extract('meta.json') + j = json.loads(open('meta.json', 'r').read()) + name = j['name'] + Path('meta.json').unlink() -# Fix the LaTeX equations in galactica -def fix_galactica(s): - s = s.replace(r'\[', r'$') - s = s.replace(r'\]', r'$') - s = s.replace(r'\(', r'$') - s = s.replace(r'\)', r'$') - s = s.replace(r'$$', r'$') - s = re.sub(r'\n', r'\n\n', s) - s = re.sub(r"\n{3,}", "\n\n", s) - return s + with open(Path(f'softprompts/{name}.zip'), 'wb') as f: + f.write(file) -def get_max_prompt_length(tokens): - global soft_prompt, soft_prompt_tensor - max_length = 2048-tokens - if soft_prompt: - max_length -= soft_prompt_tensor.shape[1] - return max_length - -def encode(prompt, tokens_to_generate=0, add_special_tokens=True): - input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) - if args.cpu or args.flexgen: - return input_ids - elif args.deepspeed: - return input_ids.to(device=local_rank) - else: - return input_ids.cuda() - -def decode(output_ids): - reply = tokenizer.decode(output_ids, skip_special_tokens=True) - reply = reply.replace(r'<|endoftext|>', '') - return reply - -def formatted_outputs(reply, model_name): - if not (args.chat or args.cai_chat): - if model_name.lower().startswith('galactica'): - reply = fix_galactica(reply) - return reply, reply, generate_basic_html(reply) - elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): - reply = fix_gpt4chan(reply) - return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) - else: - return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply) - else: - return reply - -def generate_softprompt_input_tensors(input_ids): - inputs_embeds = model.transformer.wte(input_ids) - inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1) - filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device) - filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens - return inputs_embeds, filler_input_ids - -def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): - global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor - - original_question = question - if not (args.chat or args.cai_chat): - question = apply_extensions(question, "input") - if args.verbose: - print(f"\n\n{question}\n--------------------\n") - - input_ids = encode(question, tokens) - cuda = "" if (args.cpu or args.deepspeed or args.flexgen) else ".cuda()" - if not args.flexgen: - n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1] - else: - n = tokenizer(eos_token).input_ids[0] if eos_token else None - - if stopping_string is not None: - # The stopping_criteria code below was copied from - # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py - t = encode(stopping_string, 0, add_special_tokens=False) - stopping_criteria_list = transformers.StoppingCriteriaList([ - _SentinelTokenStoppingCriteria( - sentinel_token_ids=t, - starting_idx=len(input_ids[0]) - ) - ]) - else: - stopping_criteria_list = None - - if not args.flexgen: - generate_params = [ - f"eos_token_id={n}", - f"stopping_criteria=stopping_criteria_list", - f"do_sample={do_sample}", - f"temperature={temperature}", - f"top_p={top_p}", - f"typical_p={typical_p}", - f"repetition_penalty={repetition_penalty}", - f"top_k={top_k}", - f"min_length={min_length if args.no_stream else 0}", - f"no_repeat_ngram_size={no_repeat_ngram_size}", - f"num_beams={num_beams}", - f"penalty_alpha={penalty_alpha}", - f"length_penalty={length_penalty}", - f"early_stopping={early_stopping}", - ] - else: - generate_params = [ - f"do_sample={do_sample}", - f"temperature={temperature}", - f"stop={n}", - ] - - if args.deepspeed: - generate_params.append("synced_gpus=True") - if args.no_stream: - generate_params.append(f"max_new_tokens=tokens") - else: - generate_params.append(f"max_new_tokens=8") - - if soft_prompt: - inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) - generate_params.insert(0, "inputs_embeds=inputs_embeds") - generate_params.insert(0, "filler_input_ids") - else: - generate_params.insert(0, "input_ids") - - # Generate the entire reply at once - if args.no_stream: - t0 = time.time() - with torch.no_grad(): - output = eval(f"model.generate({', '.join(generate_params)}){cuda}")[0] - if soft_prompt: - output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - - reply = decode(output) - if not (args.chat or args.cai_chat): - reply = original_question + apply_extensions(reply[len(question):], "output") - yield formatted_outputs(reply, model_name) - - t1 = time.time() - print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)") - - # Generate the reply 8 tokens at a time - else: - yield formatted_outputs(original_question, model_name) - for i in tqdm(range(tokens//8+1)): - with torch.no_grad(): - output = eval(f"model.generate({', '.join(generate_params)}){cuda}")[0] - if soft_prompt: - output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - - reply = decode(output) - if not (args.chat or args.cai_chat): - reply = original_question + apply_extensions(reply[len(question):], "output") - yield formatted_outputs(reply, model_name) - - if not args.flexgen: - input_ids = torch.reshape(output, (1, output.shape[0])) - else: - input_ids = np.reshape(output, (1, output.shape[0])) - if soft_prompt: - inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) - - if output[-1] == n: - break - -def apply_extensions(text, typ): - global available_extensions, extension_state - for ext in sorted(extension_state, key=lambda x : extension_state[x][1]): - if extension_state[ext][0] == True: - ext_string = f"extensions.{ext}.script" - if typ == "input" and hasattr(eval(ext_string), "input_modifier"): - text = eval(f"{ext_string}.input_modifier(text)") - elif typ == "output" and hasattr(eval(ext_string), "output_modifier"): - text = eval(f"{ext_string}.output_modifier(text)") - elif typ == "bot_prefix" and hasattr(eval(ext_string), "bot_prefix_modifier"): - text = eval(f"{ext_string}.bot_prefix_modifier(text)") - return text - -def update_extensions_parameters(*kwargs): - i = 0 - for ext in sorted(extension_state, key=lambda x : extension_state[x][1]): - if extension_state[ext][0] == True: - params = eval(f"extensions.{ext}.script.params") - for param in params: - if len(kwargs) >= i+1: - params[param] = eval(f"kwargs[{i}]") - i += 1 - -def get_available_models(): - return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower) - -def get_available_presets(): - return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) - -def get_available_characters(): - return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower) - -def get_available_extensions(): - return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower) - -def get_available_softprompts(): - return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) - -def create_extensions_block(): - extensions_ui_elements = [] - default_values = [] - if not (args.chat or args.cai_chat): - gr.Markdown('## Extensions parameters') - for ext in sorted(extension_state, key=lambda x : extension_state[x][1]): - if extension_state[ext][0] == True: - params = eval(f"extensions.{ext}.script.params") - for param in params: - _id = f"{ext}-{param}" - default_value = settings[_id] if _id in settings else params[param] - default_values.append(default_value) - if type(params[param]) == str: - extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}")) - elif type(params[param]) in [int, float]: - extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}")) - elif type(params[param]) == bool: - extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}")) - - update_extensions_parameters(*default_values) - btn_extensions = gr.Button("Apply") - btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], []) + return name def create_settings_menus(): - generate_params = load_preset_values(settings[f'preset{suffix}'] if not args.flexgen else 'Naive', return_dict=True) + generate_params = load_preset_values(shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', return_dict=True) with gr.Row(): with gr.Column(): with gr.Row(): - model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') - create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button") + model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model') + ui.create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button") with gr.Column(): with gr.Row(): - preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'] if not args.flexgen else 'Naive', label='Generation parameters preset') - create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button") + preset_menu = gr.Dropdown(choices=available_presets, value=shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset') + ui.create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button") with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"): with gr.Row(): @@ -531,7 +119,7 @@ def create_settings_menus(): no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size") with gr.Row(): typical_p = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label="typical_p") - min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream) + min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if shared.args.no_stream else 0, label="min_length", interactive=shared.args.no_stream) gr.Markdown("Contrastive search:") penalty_alpha = gr.Slider(0, 5, value=generate_params["penalty_alpha"], label="penalty_alpha") @@ -545,7 +133,7 @@ def create_settings_menus(): with gr.Accordion("Soft prompt", open=False, elem_id="accordion"): with gr.Row(): softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt') - create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button") + ui.create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button") gr.Markdown('Upload a soft prompt (.zip format):') with gr.Row(): @@ -557,381 +145,18 @@ def create_settings_menus(): upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu]) return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping -# This gets the new line characters right. -def clean_chat_message(text): - text = text.replace('\n', '\n\n') - text = re.sub(r"\n{3,}", "\n\n", text) - text = text.strip() - return text - -def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False): - global soft_prompt, soft_prompt_tensor - - text = clean_chat_message(text) - rows = [f"{context.strip()}\n"] - i = len(history['internal'])-1 - count = 0 - - if soft_prompt: - chat_prompt_size -= soft_prompt_tensor.shape[1] - max_length = min(get_max_prompt_length(tokens), chat_prompt_size) - - while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length: - rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n") - count += 1 - if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'): - rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n") - count += 1 - i -= 1 - - if not impersonate: - rows.append(f"{name1}: {text}\n") - rows.append(apply_extensions(f"{name2}:", "bot_prefix")) - limit = 3 - else: - rows.append(f"{name1}:") - limit = 2 - - while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length: - rows.pop(1) - rows.pop(1) - - question = ''.join(rows) - return question - -def extract_message_from_reply(question, reply, current, other, check, extensions=False): - next_character_found = False - substring_found = False - - previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", question)] - idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", reply)] - idx = idx[len(previous_idx)-1] - - if extensions: - reply = reply[idx + 1 + len(apply_extensions(f"{current}:", "bot_prefix")):] - else: - reply = reply[idx + 1 + len(f"{current}:"):] - - if check: - reply = reply.split('\n')[0].strip() - else: - idx = reply.find(f"\n{other}:") - if idx != -1: - reply = reply[:idx] - next_character_found = True - reply = clean_chat_message(reply) - - # Detect if something like "\nYo" is generated just before - # "\nYou:" is completed - tmp = f"\n{other}:" - for j in range(1, len(tmp)): - if reply[-j:] == tmp[:j]: - substring_found = True - - return reply, next_character_found, substring_found - -def generate_chat_picture(picture, name1, name2): - text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*' - buffer = BytesIO() - picture.save(buffer, format="JPEG") - img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') - visible_text = f'' - return text, visible_text - -def stop_everything_event(): - global stop_everything - stop_everything = True - -def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): - global stop_everything - stop_everything = False - - if 'pygmalion' in model_name.lower(): - name1 = "You" - - if args.picture and picture is not None: - text, visible_text = generate_chat_picture(picture, name1, name2) - else: - visible_text = text - if args.chat: - visible_text = visible_text.replace('\n', '
') - - text = apply_extensions(text, "input") - question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size) - eos_token = '\n' if check else None - first = True - for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): - reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True) - visible_reply = apply_extensions(reply, "output") - if args.chat: - visible_reply = visible_reply.replace('\n', '
') - - # We need this global variable to handle the Stop event, - # otherwise gradio gets confused - if stop_everything: - return history['visible'] - - if first: - first = False - history['internal'].append(['', '']) - history['visible'].append(['', '']) - - history['internal'][-1] = [text, reply] - history['visible'][-1] = [visible_text, visible_reply] - if not substring_found: - yield history['visible'] - if next_character_found: - break - yield history['visible'] - -def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): - if 'pygmalion' in model_name.lower(): - name1 = "You" - - question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True) - eos_token = '\n' if check else None - for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): - reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False) - if not substring_found: - yield reply - if next_character_found: - break - yield reply - -def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): - for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): - yield generate_chat_html(_history, name1, name2, character) - -def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): - if character is not None and len(history['visible']) == 1: - if args.cai_chat: - yield generate_chat_html(history['visible'], name1, name2, character) - else: - yield history['visible'] - else: - last_visible = history['visible'].pop() - last_internal = history['internal'].pop() - - for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): - if args.cai_chat: - history['visible'][-1] = [last_visible[0], _history[-1][1]] - yield generate_chat_html(history['visible'], name1, name2, character) - else: - history['visible'][-1] = (last_visible[0], _history[-1][1]) - yield history['visible'] - -def remove_last_message(name1, name2): - if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': - last = history['visible'].pop() - history['internal'].pop() - else: - last = ['', ''] - if args.cai_chat: - return generate_chat_html(history['visible'], name1, name2, character), last[0] - else: - return history['visible'], last[0] - -def send_last_reply_to_input(): - if len(history['internal']) > 0: - return history['internal'][-1][1] - else: - return '' - -def replace_last_reply(text, name1, name2): - if len(history['visible']) > 0: - if args.cai_chat: - history['visible'][-1][1] = text - else: - history['visible'][-1] = (history['visible'][-1][0], text) - history['internal'][-1][1] = apply_extensions(text, "input") - - if args.cai_chat: - return generate_chat_html(history['visible'], name1, name2, character) - else: - return history['visible'] - -def clear_html(): - return generate_chat_html([], "", "", character) - -def clear_chat_log(_character, name1, name2): - global history - if _character != 'None': - for i in range(len(history['internal'])): - if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]: - history['visible'] = [['', history['internal'][i][1]]] - history['internal'] = history['internal'][:i+1] - break - else: - history['internal'] = [] - history['visible'] = [] - if args.cai_chat: - return generate_chat_html(history['visible'], name1, name2, character) - else: - return history['visible'] - -def redraw_html(name1, name2): - global history - return generate_chat_html(history['visible'], name1, name2, character) - -def tokenize_dialogue(dialogue, name1, name2): - _history = [] - - dialogue = re.sub('', '', dialogue) - dialogue = re.sub('', '', dialogue) - dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) - dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\g<1>{name2}:', dialogue) - idx = [m.start() for m in re.finditer(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):", dialogue)] - if len(idx) == 0: - return _history - - messages = [] - for i in range(len(idx)-1): - messages.append(dialogue[idx[i]:idx[i+1]].strip()) - messages.append(dialogue[idx[-1]:].strip()) - - entry = ['', ''] - for i in messages: - if i.startswith(f'{name1}:'): - entry[0] = i[len(f'{name1}:'):].strip() - elif i.startswith(f'{name2}:'): - entry[1] = i[len(f'{name2}:'):].strip() - if not (len(entry[0]) == 0 and len(entry[1]) == 0): - _history.append(entry) - entry = ['', ''] - - print(f"\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='') - for row in _history: - for column in row: - print("\n") - for line in column.strip().split('\n'): - print("| "+line+"\n") - print("|\n") - print("------------------------------") - - return _history - -def save_history(timestamp=True): - if timestamp: - fname = f"{character or ''}{'_' if character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" - else: - fname = f"{character or ''}{'_' if character else ''}persistent.json" - if not Path('logs').exists(): - Path('logs').mkdir() - with open(Path(f'logs/{fname}'), 'w') as f: - f.write(json.dumps({'data': history['internal'], 'data_visible': history['visible']}, indent=2)) - return Path(f'logs/{fname}') - -def load_history(file, name1, name2): - global history - file = file.decode('utf-8') - try: - j = json.loads(file) - if 'data' in j: - history['internal'] = j['data'] - if 'data_visible' in j: - history['visible'] = j['data_visible'] - else: - history['visible'] = copy.deepcopy(history['internal']) - # Compatibility with Pygmalion AI's official web UI - elif 'chat' in j: - history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']] - if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'): - history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)] - history['visible'] = copy.deepcopy(history['internal']) - history['visible'][0][0] = '' - else: - history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)] - history['visible'] = copy.deepcopy(history['internal']) - except: - history['internal'] = tokenize_dialogue(file, name1, name2) - history['visible'] = copy.deepcopy(history['internal']) - -def load_character(_character, name1, name2): - global history, character - context = "" - history['internal'] = [] - history['visible'] = [] - if _character != 'None': - character = _character - data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read()) - name2 = data['char_name'] - if 'char_persona' in data and data['char_persona'] != '': - context += f"{data['char_name']}'s Persona: {data['char_persona']}\n" - if 'world_scenario' in data and data['world_scenario'] != '': - context += f"Scenario: {data['world_scenario']}\n" - context = f"{context.strip()}\n\n" - if 'example_dialogue' in data and data['example_dialogue'] != '': - history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2) - if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0: - history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]] - history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]] - else: - history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]] - history['visible'] += [['', "Hello there!"]] - else: - character = None - context = settings['context_pygmalion'] - name2 = settings['name2_pygmalion'] - - if Path(f'logs/{character}_persistent.json').exists(): - load_history(open(Path(f'logs/{character}_persistent.json'), 'rb').read(), name1, name2) - - if args.cai_chat: - return name2, context, generate_chat_html(history['visible'], name1, name2, character) - else: - return name2, context, history['visible'] - -def upload_character(json_file, img, tavern=False): - json_file = json_file if type(json_file) == str else json_file.decode('utf-8') - data = json.loads(json_file) - outfile_name = data["char_name"] - i = 1 - while Path(f'characters/{outfile_name}.json').exists(): - outfile_name = f'{data["char_name"]}_{i:03d}' - i += 1 - if tavern: - outfile_name = f'TavernAI-{outfile_name}' - with open(Path(f'characters/{outfile_name}.json'), 'w') as f: - f.write(json_file) - if img is not None: - img = Image.open(io.BytesIO(img)) - img.save(Path(f'characters/{outfile_name}.png')) - print(f'New character saved to "characters/{outfile_name}.json".') - return outfile_name - -def upload_tavern_character(img, name1, name2): - _img = Image.open(io.BytesIO(img)) - _img.getexif() - decoded_string = base64.b64decode(_img.info['chara']) - _json = json.loads(decoded_string) - _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} - _json['example_dialogue'] = _json['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', _json['char_name']) - return upload_character(json.dumps(_json), img, tavern=True) - -def upload_your_profile_picture(img): - img = Image.open(io.BytesIO(img)) - img.save(Path(f'img_me.png')) - print(f'Profile picture saved to "img_me.png"') - -# Global variables available_models = get_available_models() available_presets = get_available_presets() available_characters = get_available_characters() -available_extensions = get_available_extensions() available_softprompts = get_available_softprompts() -extension_state = {} -if args.extensions is not None: - for i,ext in enumerate(args.extensions.split(',')): - if ext in available_extensions: - print(f'Loading the extension "{ext}"... ', end='') - ext_string = f"extensions.{ext}.script" - exec(f"import {ext_string}") - extension_state[ext] = [True, i] - print(f'Ok.') + +extensions_module.available_extensions = get_available_extensions() +if shared.args.extensions is not None: + extensions_module.load_extensions() # Choosing the default model -if args.model is not None: - model_name = args.model +if shared.args.model is not None: + shared.model_name = shared.args.model else: if len(available_models) == 0: print("No models are available! Please download at least one.") @@ -940,43 +165,35 @@ else: i = 0 else: print("The following models are available:\n") - for i,model in enumerate(available_models): + for i, model in enumerate(available_models): print(f"{i+1}. {model}") print(f"\nWhich one do you want to load? 1-{len(available_models)}\n") i = int(input())-1 print() - model_name = available_models[i] -model, tokenizer = load_model(model_name) -loaded_preset = None -soft_prompt_tensor = None -soft_prompt = False -stop_everything = False + shared.model_name = available_models[i] +shared.model, shared.tokenizer = load_model(shared.model_name) # UI settings -if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): - default_text = settings['prompt_gpt4chan'] -elif re.match('(rosey|chip|joi)_.*_instruct.*', model_name.lower()) is not None: - default_text = 'User: \n' -else: - default_text = settings['prompt'] -description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" - -suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else '' buttons = {} gen_events = [] -history = {'internal': [], 'visible': []} -character = None - -if args.chat or args.cai_chat: +suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' +description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" +if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): + default_text = shared.settings['prompt_gpt4chan'] +elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None: + default_text = 'User: \n' +else: + default_text = shared.settings['prompt'] +if shared.args.chat or shared.args.cai_chat: if Path(f'logs/persistent.json').exists(): - load_history(open(Path(f'logs/persistent.json'), 'rb').read(), settings[f'name1{suffix}'], settings[f'name2{suffix}']) + chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']) - with gr.Blocks(css=css+chat_css, analytics_enabled=False) as interface: - if args.cai_chat: - display = gr.HTML(value=generate_chat_html(history['visible'], settings[f'name1{suffix}'], settings[f'name2{suffix}'], character)) + with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface: + if shared.args.cai_chat: + display = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) else: - display = gr.Chatbot(value=history['visible']) + display = gr.Chatbot(value=shared.history['visible']) textbox = gr.Textbox(label='Input') with gr.Row(): buttons["Stop"] = gr.Button("Stop") @@ -989,20 +206,20 @@ if args.chat or args.cai_chat: with gr.Row(): buttons["Send last reply to input"] = gr.Button("Send last reply to input") buttons["Replace last reply"] = gr.Button("Replace last reply") - if args.picture: + if shared.args.picture: with gr.Row(): picture_select = gr.Image(label="Send a picture", type='pil') with gr.Tab("Chat settings"): - name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name') - name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name') - context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context') + name1 = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') + name2 = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') + context = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=2, label='Context') with gr.Row(): character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character') - create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button") + ui.create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button") with gr.Row(): - check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') + check = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') with gr.Row(): with gr.Tab('Chat history'): with gr.Row(): @@ -1030,59 +247,59 @@ if args.chat or args.cai_chat: with gr.Tab("Generation settings"): with gr.Row(): with gr.Column(): - max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) + max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) with gr.Column(): - chat_prompt_size_slider = gr.Slider(minimum=settings['chat_prompt_size_min'], maximum=settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=settings['chat_prompt_size']) + chat_prompt_size_slider = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() - if args.extensions is not None: + if shared.args.extensions is not None: with gr.Tab("Extensions"): - create_extensions_block() + extensions_module.create_extensions_block() input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider] - if args.picture: + if shared.args.picture: input_params.append(picture_select) - function_call = "cai_chatbot_wrapper" if args.cai_chat else "chatbot_wrapper" + function_call = "chat.cai_chatbot_wrapper" if shared.args.cai_chat else "chat.chatbot_wrapper" - gen_events.append(buttons["Generate"].click(eval(function_call), input_params, display, show_progress=args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(eval(function_call), input_params, display, show_progress=args.no_stream)) - if args.picture: - picture_select.upload(eval(function_call), input_params, display, show_progress=args.no_stream) - gen_events.append(buttons["Regenerate"].click(regenerate_wrapper, input_params, display, show_progress=args.no_stream)) - gen_events.append(buttons["Impersonate"].click(impersonate_wrapper, input_params, textbox, show_progress=args.no_stream)) - buttons["Stop"].click(stop_everything_event, [], [], cancels=gen_events) + gen_events.append(buttons["Generate"].click(eval(function_call), input_params, display, show_progress=shared.args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(eval(function_call), input_params, display, show_progress=shared.args.no_stream)) + if shared.args.picture: + picture_select.upload(eval(function_call), input_params, display, show_progress=shared.args.no_stream) + gen_events.append(buttons["Regenerate"].click(chat.regenerate_wrapper, input_params, display, show_progress=shared.args.no_stream)) + gen_events.append(buttons["Impersonate"].click(chat.impersonate_wrapper, input_params, textbox, show_progress=shared.args.no_stream)) + buttons["Stop"].click(chat.stop_everything_event, [], [], cancels=gen_events) - buttons["Send last reply to input"].click(send_last_reply_to_input, [], textbox, show_progress=args.no_stream) - buttons["Replace last reply"].click(replace_last_reply, [textbox, name1, name2], display, show_progress=args.no_stream) - buttons["Clear history"].click(clear_chat_log, [character_menu, name1, name2], display) - buttons["Remove last"].click(remove_last_message, [name1, name2], [display, textbox], show_progress=False) - buttons["Download"].click(save_history, inputs=[], outputs=[download]) - buttons["Upload character"].click(upload_character, [upload_char, upload_img], [character_menu]) + buttons["Send last reply to input"].click(chat.send_last_reply_to_input, [], textbox, show_progress=shared.args.no_stream) + buttons["Replace last reply"].click(chat.replace_last_reply, [textbox, name1, name2], display, show_progress=shared.args.no_stream) + buttons["Clear history"].click(chat.clear_chat_log, [name1, name2], display) + buttons["Remove last"].click(chat.remove_last_message, [name1, name2], [display, textbox], show_progress=False) + buttons["Download"].click(chat.save_history, inputs=[], outputs=[download]) + buttons["Upload character"].click(chat.upload_character, [upload_char, upload_img], [character_menu]) # Clearing stuff and saving the history for i in ["Generate", "Regenerate", "Replace last reply"]: buttons[i].click(lambda x: "", textbox, textbox, show_progress=False) - buttons[i].click(lambda : save_history(timestamp=False), [], [], show_progress=False) - buttons["Clear history"].click(lambda : save_history(timestamp=False), [], [], show_progress=False) + buttons[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + buttons["Clear history"].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) textbox.submit(lambda x: "", textbox, textbox, show_progress=False) - textbox.submit(lambda : save_history(timestamp=False), [], [], show_progress=False) + textbox.submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display]) - upload_chat_history.upload(load_history, [upload_chat_history, name1, name2], []) - upload_img_tavern.upload(upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu]) - upload_img_me.upload(upload_your_profile_picture, [upload_img_me], []) - if args.picture: + character_menu.change(chat.load_character, [character_menu, name1, name2], [name2, context, display]) + upload_chat_history.upload(chat.load_history, [upload_chat_history, name1, name2], []) + upload_img_tavern.upload(chat.upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu]) + upload_img_me.upload(chat.upload_your_profile_picture, [upload_img_me], []) + if shared.args.picture: picture_select.upload(lambda : None, [], [picture_select], show_progress=False) - if args.cai_chat: - upload_chat_history.upload(redraw_html, [name1, name2], [display]) - upload_img_me.upload(redraw_html, [name1, name2], [display]) + if shared.args.cai_chat: + upload_chat_history.upload(chat.redraw_html, [name1, name2], [display]) + upload_img_me.upload(chat.redraw_html, [name1, name2], [display]) else: - upload_chat_history.upload(lambda : history['visible'], [], [display]) - upload_img_me.upload(lambda : history['visible'], [], [display]) + upload_chat_history.upload(lambda : shared.history['visible'], [], [display]) + upload_img_me.upload(lambda : shared.history['visible'], [], [display]) -elif args.notebook: - with gr.Blocks(css=css, analytics_enabled=False) as interface: +elif shared.args.notebook: + with gr.Blocks(css=ui.css, analytics_enabled=False) as interface: gr.Markdown(description) with gr.Tab('Raw'): textbox = gr.Textbox(value=default_text, lines=23) @@ -1094,24 +311,24 @@ elif args.notebook: buttons["Generate"] = gr.Button("Generate") buttons["Stop"] = gr.Button("Stop") - max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) + max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() - if args.extensions is not None: - create_extensions_block() + if shared.args.extensions is not None: + extensions_module.create_extensions_block() - gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream)) + gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream)) buttons["Stop"].click(None, None, None, cancels=gen_events) else: - with gr.Blocks(css=css, analytics_enabled=False) as interface: + with gr.Blocks(css=ui.css, analytics_enabled=False) as interface: gr.Markdown(description) with gr.Row(): with gr.Column(): textbox = gr.Textbox(value=default_text, lines=15, label='Input') - max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) + max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) buttons["Generate"] = gr.Button("Generate") with gr.Row(): with gr.Column(): @@ -1120,8 +337,8 @@ else: buttons["Stop"] = gr.Button("Stop") preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() - if args.extensions is not None: - create_extensions_block() + if shared.args.extensions is not None: + extensions_module.create_extensions_block() with gr.Column(): with gr.Tab('Raw'): @@ -1131,16 +348,16 @@ else: with gr.Tab('HTML'): html = gr.HTML() - gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream)) - gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream)) + gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream)) + gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream)) buttons["Stop"].click(None, None, None, cancels=gen_events) interface.queue() -if args.listen: - interface.launch(prevent_thread_lock=True, share=args.share, server_name="0.0.0.0", server_port=args.listen_port) +if shared.args.listen: + interface.launch(prevent_thread_lock=True, share=shared.args.share, server_name="0.0.0.0", server_port=shared.args.listen_port) else: - interface.launch(prevent_thread_lock=True, share=args.share, server_port=args.listen_port) + interface.launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port) # I think that I will need this later while True: