From bd8aac8fa43daa7bd0e2d3d2e446a403a447c744 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 4 Mar 2023 13:28:42 -0300
Subject: [PATCH 01/35] Add LLaMA 8-bit support

---
 modules/LLaMA_8bit.py | 125 ++++++++++++++++++++++++++++++++++++++++++
 modules/models.py     |  16 ++++--
 2 files changed, 137 insertions(+), 4 deletions(-)
 create mode 100644 modules/LLaMA_8bit.py

diff --git a/modules/LLaMA_8bit.py b/modules/LLaMA_8bit.py
new file mode 100644
index 00000000..a339277c
--- /dev/null
+++ b/modules/LLaMA_8bit.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from typing import Tuple
+import os
+import sys
+import torch
+import fire
+import time
+import json
+
+from pathlib import Path
+
+from fairscale.nn.model_parallel.initialize import initialize_model_parallel
+
+from repositories.llama_int8.llama import ModelArgs, Transformer, Tokenizer, LLaMA
+
+
+def setup_model_parallel() -> Tuple[int, int]:
+    local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+    torch.distributed.init_process_group("nccl")
+    initialize_model_parallel(world_size)
+    torch.cuda.set_device(local_rank)
+
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    return local_rank, world_size
+
+
+def load(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    max_seq_len: int,
+    max_batch_size: int,
+) -> LLaMA:
+    start_time = time.time()
+    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+
+    with open(Path(ckpt_dir) / "params.json", "r") as f:
+        params = json.loads(f.read())
+
+    model_args: ModelArgs = ModelArgs(
+        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+    )
+    tokenizer = Tokenizer(model_path=tokenizer_path)
+    model_args.vocab_size = tokenizer.n_words
+    # torch.set_default_tensor_type(torch.cuda.HalfTensor)
+    torch.set_default_tensor_type(torch.HalfTensor)
+    print("Creating transformer")
+    model = Transformer(model_args)
+    print("Transformer created")
+
+    key_to_dim = {
+        "w1": 0,
+        "w2": -1,
+        "w3": 0,
+        "wo": -1,
+        "wq": 0,
+        "wk": 0,
+        "wv": 0,
+        "output": 0,
+        "tok_embeddings": -1,
+        "ffn_norm": None,
+        "attention_norm": None,
+        "norm": None,
+        "rope": None,
+    }
+
+    # ?
+    torch.set_default_tensor_type(torch.FloatTensor)
+
+    # load the state dict incrementally, to avoid memory problems
+    for i, ckpt in enumerate(checkpoints):
+        print(f"Loading checkpoint {i}")
+        checkpoint = torch.load(ckpt, map_location="cpu")
+        for parameter_name, parameter in model.named_parameters():
+            short_name = parameter_name.split(".")[-2]
+            if key_to_dim[short_name] is None and i == 0:
+                parameter.data = checkpoint[parameter_name]
+            elif key_to_dim[short_name] == 0:
+                size = checkpoint[parameter_name].size(0)
+                parameter.data[size * i : size * (i + 1), :] = checkpoint[
+                    parameter_name
+                ]
+            elif key_to_dim[short_name] == -1:
+                size = checkpoint[parameter_name].size(-1)
+                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
+                    parameter_name
+                ]
+        del checkpoint
+
+    # model.load_state_dict(checkpoint, strict=False)
+    model.quantize()
+
+    generator = LLaMA(model, tokenizer)
+    print(f"Loaded in {time.time() - start_time:.2f} seconds")
+    return generator
+
+
+class LLaMAModel_8bit:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(self, path, max_seq_len=2048, max_batch_size=1):
+        tokenizer_path = path / "tokenizer.model"
+        path = os.path.abspath(path)
+        tokenizer_path = os.path.abspath(tokenizer_path)
+
+        generator = load(path, tokenizer_path, max_seq_len, max_batch_size)
+
+        result = self()
+        result.pipeline = generator
+        return result
+
+    def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):
+
+        results = self.pipeline.generate(
+            [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
+        )
+
+        return results[0]
+
diff --git a/modules/models.py b/modules/models.py
index 904d8ae2..c7b75bb9 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -88,12 +88,20 @@ def load_model(model_name):
 
     # LLaMA model (not on HuggingFace)
     elif shared.is_LLaMA:
-        import modules.LLaMA
-        from modules.LLaMA import LLaMAModel
+        if shared.args.load_in_8bit:
+            import modules.LLaMA_8bit
+            from modules.LLaMA_8bit import LLaMAModel_8bit
 
-        model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
+            model = LLaMAModel_8bit.from_pretrained(Path(f'models/{model_name}'))
 
-        return model, None
+            return model, None
+        else:
+            import modules.LLaMA
+            from modules.LLaMA import LLaMAModel
+
+            model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
+
+            return model, None
 
     # Custom
     else:

From c33715ad5b32e59bf61c5ca3569a9a890b3afc81 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 5 Mar 2023 01:20:31 -0300
Subject: [PATCH 02/35] Move towards HF LLaMA implementation

---
 modules/LLaMA.py           |  96 ----------------------------
 modules/LLaMA_8bit.py      | 125 -------------------------------------
 modules/models.py          |  20 +-----
 modules/shared.py          |   2 -
 modules/text_generation.py |   4 +-
 requirements.txt           |   2 +-
 6 files changed, 4 insertions(+), 245 deletions(-)
 delete mode 100644 modules/LLaMA.py
 delete mode 100644 modules/LLaMA_8bit.py

diff --git a/modules/LLaMA.py b/modules/LLaMA.py
deleted file mode 100644
index 3781ccf5..00000000
--- a/modules/LLaMA.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the GNU General Public License version 3.
- -import json -import os -import sys -import time -from pathlib import Path -from typing import Tuple - -import fire -import torch -from fairscale.nn.model_parallel.initialize import initialize_model_parallel -from llama import LLaMA, ModelArgs, Tokenizer, Transformer - -os.environ['RANK'] = '0' -os.environ['WORLD_SIZE'] = '1' -os.environ['MP'] = '1' -os.environ['MASTER_ADDR'] = '127.0.0.1' -os.environ['MASTER_PORT'] = '2223' - -def setup_model_parallel() -> Tuple[int, int]: - local_rank = int(os.environ.get("LOCAL_RANK", -1)) - world_size = int(os.environ.get("WORLD_SIZE", -1)) - - torch.distributed.init_process_group("gloo") - initialize_model_parallel(world_size) - torch.cuda.set_device(local_rank) - - # seed must be the same in all processes - torch.manual_seed(1) - return local_rank, world_size - -def load( - ckpt_dir: str, - tokenizer_path: str, - local_rank: int, - world_size: int, - max_seq_len: int, - max_batch_size: int, -) -> LLaMA: - start_time = time.time() - checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - assert world_size == len( - checkpoints - ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}" - ckpt_path = checkpoints[local_rank] - print("Loading") - checkpoint = torch.load(ckpt_path, map_location="cpu") - with open(Path(ckpt_dir) / "params.json", "r") as f: - params = json.loads(f.read()) - - model_args: ModelArgs = ModelArgs( - max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params - ) - tokenizer = Tokenizer(model_path=tokenizer_path) - model_args.vocab_size = tokenizer.n_words - torch.set_default_tensor_type(torch.cuda.HalfTensor) - model = Transformer(model_args) - torch.set_default_tensor_type(torch.FloatTensor) - model.load_state_dict(checkpoint, strict=False) - - generator = LLaMA(model, tokenizer) - print(f"Loaded in {time.time() - start_time:.2f} seconds") - return generator - - -class LLaMAModel: - def __init__(self): - pass - - @classmethod - def from_pretrained(self, path, max_seq_len=2048, max_batch_size=1): - tokenizer_path = path / "tokenizer.model" - path = os.path.abspath(path) - tokenizer_path = os.path.abspath(tokenizer_path) - - local_rank, world_size = setup_model_parallel() - if local_rank > 0: - sys.stdout = open(os.devnull, "w") - - generator = load( - path, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size - ) - - result = self() - result.pipeline = generator - return result - - def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95): - - results = self.pipeline.generate( - [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p - ) - - return results[0] diff --git a/modules/LLaMA_8bit.py b/modules/LLaMA_8bit.py deleted file mode 100644 index a339277c..00000000 --- a/modules/LLaMA_8bit.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the GNU General Public License version 3. 
- -from typing import Tuple -import os -import sys -import torch -import fire -import time -import json - -from pathlib import Path - -from fairscale.nn.model_parallel.initialize import initialize_model_parallel - -from repositories.llama_int8.llama import ModelArgs, Transformer, Tokenizer, LLaMA - - -def setup_model_parallel() -> Tuple[int, int]: - local_rank = int(os.environ.get("LOCAL_RANK", -1)) - world_size = int(os.environ.get("WORLD_SIZE", -1)) - - torch.distributed.init_process_group("nccl") - initialize_model_parallel(world_size) - torch.cuda.set_device(local_rank) - - # seed must be the same in all processes - torch.manual_seed(1) - return local_rank, world_size - - -def load( - ckpt_dir: str, - tokenizer_path: str, - max_seq_len: int, - max_batch_size: int, -) -> LLaMA: - start_time = time.time() - checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - - with open(Path(ckpt_dir) / "params.json", "r") as f: - params = json.loads(f.read()) - - model_args: ModelArgs = ModelArgs( - max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params - ) - tokenizer = Tokenizer(model_path=tokenizer_path) - model_args.vocab_size = tokenizer.n_words - # torch.set_default_tensor_type(torch.cuda.HalfTensor) - torch.set_default_tensor_type(torch.HalfTensor) - print("Creating transformer") - model = Transformer(model_args) - print("Transformer created") - - key_to_dim = { - "w1": 0, - "w2": -1, - "w3": 0, - "wo": -1, - "wq": 0, - "wk": 0, - "wv": 0, - "output": 0, - "tok_embeddings": -1, - "ffn_norm": None, - "attention_norm": None, - "norm": None, - "rope": None, - } - - # ? - torch.set_default_tensor_type(torch.FloatTensor) - - # load the state dict incrementally, to avoid memory problems - for i, ckpt in enumerate(checkpoints): - print(f"Loading checkpoint {i}") - checkpoint = torch.load(ckpt, map_location="cpu") - for parameter_name, parameter in model.named_parameters(): - short_name = parameter_name.split(".")[-2] - if key_to_dim[short_name] is None and i == 0: - parameter.data = checkpoint[parameter_name] - elif key_to_dim[short_name] == 0: - size = checkpoint[parameter_name].size(0) - parameter.data[size * i : size * (i + 1), :] = checkpoint[ - parameter_name - ] - elif key_to_dim[short_name] == -1: - size = checkpoint[parameter_name].size(-1) - parameter.data[:, size * i : size * (i + 1)] = checkpoint[ - parameter_name - ] - del checkpoint - - # model.load_state_dict(checkpoint, strict=False) - model.quantize() - - generator = LLaMA(model, tokenizer) - print(f"Loaded in {time.time() - start_time:.2f} seconds") - return generator - - -class LLaMAModel_8bit: - def __init__(self): - pass - - @classmethod - def from_pretrained(self, path, max_seq_len=2048, max_batch_size=1): - tokenizer_path = path / "tokenizer.model" - path = os.path.abspath(path) - tokenizer_path = os.path.abspath(tokenizer_path) - - generator = load(path, tokenizer_path, max_seq_len, max_batch_size) - - result = self() - result.pipeline = generator - return result - - def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95): - - results = self.pipeline.generate( - [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p - ) - - return results[0] - diff --git a/modules/models.py b/modules/models.py index c7b75bb9..40feb8b3 100644 --- a/modules/models.py +++ b/modules/models.py @@ -39,10 +39,9 @@ def load_model(model_name): t0 = time.time() shared.is_RWKV = model_name.lower().startswith('rwkv-') - shared.is_LLaMA = model_name.lower().startswith('llama-') # Default settings - if not 
(shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV or shared.is_LLaMA): + if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) else: @@ -86,23 +85,6 @@ def load_model(model_name): return model, None - # LLaMA model (not on HuggingFace) - elif shared.is_LLaMA: - if shared.args.load_in_8bit: - import modules.LLaMA_8bit - from modules.LLaMA_8bit import LLaMAModel_8bit - - model = LLaMAModel_8bit.from_pretrained(Path(f'models/{model_name}')) - - return model, None - else: - import modules.LLaMA - from modules.LLaMA import LLaMAModel - - model = LLaMAModel.from_pretrained(Path(f'models/{model_name}')) - - return model, None - # Custom else: command = "AutoModelForCausalLM.from_pretrained" diff --git a/modules/shared.py b/modules/shared.py index e9dfdaa2..29276fd3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -6,7 +6,6 @@ model_name = "" soft_prompt_tensor = None soft_prompt = False is_RWKV = False -is_LLaMA = False # Chat variables history = {'internal': [], 'visible': []} @@ -43,7 +42,6 @@ settings = { 'default': 'NovelAI-Sphinx Moth', 'pygmalion-*': 'Pygmalion', 'RWKV-*': 'Naive', - 'llama-*': 'Naive', '(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)' }, 'prompts': { diff --git a/modules/text_generation.py b/modules/text_generation.py index f9082a31..ee93fb7c 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -24,7 +24,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True): # These models do not have explicit tokenizers for now, so # we return an estimate for the number of tokens - if shared.is_RWKV or shared.is_LLaMA: + if shared.is_RWKV: return np.zeros((1, len(prompt)//4)) input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) @@ -90,7 +90,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # These models are not part of Hugging Face, so we handle them # separately and terminate the function call earlier - if shared.is_RWKV or shared.is_LLaMA: + if shared.is_RWKV: if shared.args.no_stream: reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p) t1 = time.time() diff --git a/requirements.txt b/requirements.txt index 55aeb8fd..70dc8349 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ gradio==3.18.0 numpy rwkv==0.0.6 safetensors==0.2.8 -git+https://github.com/huggingface/transformers +git+https://github.com/oobabooga/transformers@llama_push From 5492e2e9f81fce8c0b45ef966459a1cb0a635a61 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 5 Mar 2023 10:02:24 -0300 Subject: [PATCH 03/35] Add sentencepiece --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 70dc8349..e254336b 100644 --- a/requirements.txt +++ b/requirements.txt @@ 
-5,4 +5,5 @@ gradio==3.18.0 numpy rwkv==0.0.6 safetensors==0.2.8 +sentencepiece git+https://github.com/oobabooga/transformers@llama_push From 8e706df20e7a22d695e1e99e8bc393ff0ced74ff Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 5 Mar 2023 10:12:43 -0300 Subject: [PATCH 04/35] Fix a memory leak when text streaming is on --- modules/text_generation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index ee93fb7c..7b6bdcf5 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -81,11 +81,13 @@ def formatted_outputs(reply, model_name): else: return reply -def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): +def clear_torch_cache(): gc.collect() if not shared.args.cpu: torch.cuda.empty_cache() +def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): + clear_torch_cache() t0 = time.time() # These models are not part of Hugging Face, so we handle them @@ -98,6 +100,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi yield formatted_outputs(reply, shared.model_name) else: for i in tqdm(range(max_new_tokens//8+1)): + clear_torch_cache() reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p) yield formatted_outputs(reply, shared.model_name) question = reply @@ -183,6 +186,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi else: yield formatted_outputs(original_question, shared.model_name) for i in tqdm(range(max_new_tokens//8+1)): + clear_torch_cache() + with torch.no_grad(): output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] if shared.soft_prompt: From a54b91af778ffb89193874a11ede74a0b1b0cd41 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 5 Mar 2023 10:21:15 -0300 Subject: [PATCH 05/35] Improve readability --- modules/text_generation.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index 7b6bdcf5..e6207b56 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -115,7 +115,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi input_ids = encode(question, max_new_tokens) cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" n = shared.tokenizer.eos_token_id if eos_token is None else encode(eos_token)[0][-1] - if stopping_string is not None: # The stopping_criteria code below was copied from # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py @@ -152,14 +151,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi f"temperature={temperature}", f"stop={n}", ] - if shared.args.deepspeed: generate_params.append("synced_gpus=True") if shared.args.no_stream: generate_params.append("max_new_tokens=max_new_tokens") else: generate_params.append("max_new_tokens=8") - if shared.soft_prompt: inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) generate_params.insert(0, 
"inputs_embeds=inputs_embeds") From 2af66a4d4cd570f22a5d1e1b509f767828357089 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 5 Mar 2023 16:08:50 -0300 Subject: [PATCH 06/35] Fix in pygmalion replies --- modules/chat.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 3b4cbba3..8e64612e 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -51,23 +51,29 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat prompt = ''.join(rows) return prompt -def extract_message_from_reply(question, reply, current, other, check, extensions=False): +def extract_message_from_reply(question, reply, name1, name2, check, impersonate=False): next_character_found = False substring_found = False - previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", question)] - idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", reply)] - idx = idx[len(previous_idx)-1] + asker = name1 if not impersonate else name2 + replier = name2 if not impersonate else name1 - if extensions: - reply = reply[idx + 1 + len(apply_extensions(f"{current}:", "bot_prefix")):] + previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", question)] + idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", reply)] + idx = idx[max(len(previous_idx)-1, 0)] + + if not impersonate: + reply = reply[idx + 1 + len(apply_extensions(f"{replier}:", "bot_prefix")):] else: - reply = reply[idx + 1 + len(f"{current}:"):] + reply = reply[idx + 1 + len(f"{replier}:"):] if check: - reply = reply.split('\n')[0].strip() + lines = reply.split('\n') + reply = lines[0].strip() + if len(lines) > 1: + next_character_found = True else: - idx = reply.find(f"\n{other}:") + idx = reply.find(f"\n{asker}:") if idx != -1: reply = reply[:idx] next_character_found = True @@ -75,7 +81,7 @@ def extract_message_from_reply(question, reply, current, other, check, extension # Detect if something like "\nYo" is generated just before # "\nYou:" is completed - tmp = f"\n{other}:" + tmp = f"\n{asker}:" for j in range(1, len(tmp)): if reply[-j:] == tmp[:j]: substring_found = True @@ -89,6 +95,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical shared.stop_everything = False just_started = True eos_token = '\n' if check else None + name1_original = name1 if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -119,7 +126,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): # Extracting the reply - reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name2, name1, check, extensions=True) + reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check) + reply = re.sub("(||{{user}})", name1_original, reply) visible_reply = apply_extensions(reply, "output") if shared.args.chat: visible_reply = visible_reply.replace('\n', '
') @@ -139,6 +147,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical yield shared.history['visible'] if next_character_found: break + yield shared.history['visible'] def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): @@ -152,7 +161,7 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ reply = '' for i in range(chat_generation_attempts): for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): - reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, extensions=False) + reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True) if not substring_found: yield reply if next_character_found: From 145c725c395f1bfbb448217ef5bb98412ed9f3ce Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 5 Mar 2023 16:28:21 -0300 Subject: [PATCH 07/35] Bump RWKV version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e254336b..2051dc0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ bitsandbytes==0.37.0 flexgen==0.1.7 gradio==3.18.0 numpy -rwkv==0.0.6 +rwkv==0.0.7 safetensors==0.2.8 sentencepiece git+https://github.com/oobabooga/transformers@llama_push From c855b828fe48902f72985602bf2c0967a6a298c9 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 5 Mar 2023 17:01:47 -0300 Subject: [PATCH 08/35] Better handle --- modules/chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 8e64612e..f40f8299 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -127,8 +127,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical # Extracting the reply reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check) - reply = re.sub("(||{{user}})", name1_original, reply) - visible_reply = apply_extensions(reply, "output") + visible_reply = re.sub("(||{{user}})", name1_original, reply) + visible_reply = apply_extensions(visible_reply, "output") if shared.args.chat: visible_reply = visible_reply.replace('\n', '
') From 9907bee4a45a9ce9c287b5c294f525aa8bae79b0 Mon Sep 17 00:00:00 2001 From: MetaIX <125941078+MetaIX@users.noreply.github.com> Date: Sun, 5 Mar 2023 19:04:22 -0600 Subject: [PATCH 09/35] Support for Eleven Labs TTS As per your suggestion at https://github.com/oobabooga/text-generation-webui/issues/159 here's my attempt. I'm brand new to python and github. Completely different from unreal + visual coding, so forgive my amateurish code. This essentially adds support for Eleven Labs TTS. Tested it without major issues, and I believe it's functional (hopefully). Extra requirements: elevenlabslib https://github.com/lugia19/elevenlabslib, sounddevice0.4.6, and soundfile Folder structure is the same as the SileroTTS Extension. --- extensions/silero_tts/script.py | 102 ++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 38 deletions(-) diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index f697d0e2..e088ec8e 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -3,26 +3,55 @@ from pathlib import Path import gradio as gr import torch +import io +import json +import os -torch._C._jit_set_profiling_mode(False) +import requests + +from elevenlabslib.helpers import * +from elevenlabslib import * params = { 'activate': True, - 'speaker': 'en_56', - 'language': 'en', - 'model_id': 'v3_en', - 'sample_rate': 48000, - 'device': 'cpu', + 'api_key': '12345', + 'selected_voice': 'None', } -current_params = params.copy() -voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] +initial_voice = ['None'] wav_idx = 0 +user = ElevenLabsUser(params['api_key']) +user_info = None -def load_model(): - model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) - model.to(params['device']) - return model -model = load_model() + +"Check if the API is valid and refresh the UI accordingly." 
+def check_valid_api():
+
+    global user, user_info, params
+
+    user = ElevenLabsUser(params['api_key'])
+    user_info = user._get_subscription_data()
+    print('checking api')
+    if params['activate'] == False:
+        return gr.update(value='Disconnected')
+    elif user_info is None:
+        print('Incorrect API Key')
+        return gr.update(value='Disconnected')
+    else:
+        print('Got an API Key!')
+        return gr.update(value='Connected')
+
+"Once the API is verified, get the available voices and update the dropdown list"
+def refresh_voices():
+
+    global user, user_info
+
+    your_voices = [None]
+    if user_info is not None:
+        for voice in user.get_available_voices():
+            your_voices.append(voice.initialName)
+        return gr.Dropdown.update(choices=your_voices)
+    else:
+        return
 
 def remove_surrounded_chars(string):
     new_string = ""
@@ -46,17 +75,12 @@ def output_modifier(string):
     """
     This function is applied to the model outputs.
     """
-
-    global wav_idx, model, current_params
-
-    for i in params:
-        if params[i] != current_params[i]:
-            model = load_model()
-            current_params = params.copy()
-            break
-
+    global params, wav_idx, user, user_info
+
     if params['activate'] == False:
         return string
+    elif user_info == None:
+        return string
 
     string = remove_surrounded_chars(string)
     string = string.replace('"', '')
@@ -66,29 +90,31 @@ def output_modifier(string):
 
     if string == '':
         string = 'empty reply, try regenerating'
+
+    output_file = Path('extensions/elevenlabs_tts/outputs/{}.wav'.format(wav_idx))
+    voice = user.get_voices_by_name(params['selected_voice'])[0]
+    audio_data = voice.generate_audio_bytes(string)
+    save_bytes_to_path("extensions/elevenlabs_tts/outputs/{}.wav".format(wav_idx), audio_data)
 
-    output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
-    audio = model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
     string = f''
     wav_idx += 1
-
     return string
 
-def bot_prefix_modifier(string):
-    """
-    This function is only applied in chat mode. It modifies
-    the prefix text for the Bot and can be used to bias its
-    behavior.
-    """
-
-    return string
 
 def ui():
     # Gradio elements
-    activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
-    voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
-
+    with gr.Row():
+        activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
+        connection_status = gr.Textbox(value='Disconnected', label='Connection Status')
+    voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice')
+    with gr.Row():
+        api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
+        connect = gr.Button(value='Connect')
     # Event functions to update the parameters in the backend
-    activate.change(lambda x: params.update({"activate": x}), activate, None)
-    voice.change(lambda x: params.update({"speaker": x}), voice, None)
+    activate.change(lambda x: params.update({'activate': x}), activate, None)
+    voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
+    api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
+    connect.click(check_valid_api, [], connection_status)
+    connect.click(refresh_voices, [], voice)
+

From 53ce21ac68846704004f01a0aac9463ebce92ecb Mon Sep 17 00:00:00 2001
From: Mug <>
Date: Mon, 6 Mar 2023 12:13:50 +0100
Subject: [PATCH 10/35] Add api example using websockets

---
 api-example-stream.py | 81 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 api-example-stream.py

diff --git a/api-example-stream.py b/api-example-stream.py
new file mode 100644
index 00000000..0d93b4b6
--- /dev/null
+++ b/api-example-stream.py
@@ -0,0 +1,81 @@
+import string
+import random
+import websockets
+import json
+import asyncio
+
+def random_hash():
+    letters = string.ascii_lowercase + string.digits
+    return ''.join(random.choice(letters) for i in range(9))
+
+async def run(context):
+    server = "127.0.0.1"
+    params = {
+        'max_new_tokens': 200,
+        'do_sample': True,
+        'temperature': 0.5,
+        'top_p': 0.9,
+        'typical_p': 1,
+        'repetition_penalty': 1.05,
+        'top_k': 0,
+        'min_length': 0,
+        'no_repeat_ngram_size': 0,
+        'num_beams': 1,
+        'penalty_alpha': 0,
+        'length_penalty': 1,
+        'early_stopping': False,
+    }
+    session = random_hash()
+
+    async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
+        while content := json.loads(await websocket.recv()):
+            #Python3.10 syntax, replace with if elif on older
+            match content["msg"]:
+                case "send_hash":
+                    await websocket.send(json.dumps({
+                        "session_hash": session,
+                        "fn_index": 7
+                    }))
+                case "estimation":
+                    pass
+                case "send_data":
+                    await websocket.send(json.dumps({
+                        "session_hash": session,
+                        "fn_index": 7,
+                        "data": [
+                            context,
+                            params['max_new_tokens'],
+                            params['do_sample'],
+                            params['temperature'],
+                            params['top_p'],
+                            params['typical_p'],
+                            params['repetition_penalty'],
+                            params['top_k'],
+                            params['min_length'],
+                            params['no_repeat_ngram_size'],
+                            params['num_beams'],
+                            params['penalty_alpha'],
+                            params['length_penalty'],
+                            params['early_stopping'],
+                        ]
+                    }))
+                case "process_starts":
+                    pass
+                case "process_generating" | "process_completed":
+                    yield content["output"]["data"][0]
+                    # You can search for your desired end indicator and
+                    # stop generation by closing the websocket here
+                    if (content["msg"] == "process_completed"):
+                        break
+
+prompt = "What I would like to say is the following: "
+
+async def get_result():
+    async for response in run(prompt):
+        # Print intermediate steps
+        print(response)
+
+    # Print final result
+    print(response)
+
+asyncio.run(get_result())
\ No newline at end of file From e91f4bc25a4838137a39817c7f4154bc69de4069 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 08:45:49 -0300 Subject: [PATCH 11/35] Add RWKV tokenizer --- modules/RWKV.py | 20 ++++++++++++++++++++ modules/models.py | 5 +++-- modules/text_generation.py | 24 +++++++++++------------- 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index 46d8ff5f..2ef86cae 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -2,6 +2,7 @@ import os from pathlib import Path import numpy as np +from tokenizers import Tokenizer import modules.shared as shared @@ -43,3 +44,22 @@ class RWKVModel: ) return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) + +class RWKVTokenizer: + def __init__(self): + pass + + @classmethod + def from_pretrained(self, path): + tokenizer_path = path / "20B_tokenizer.json" + tokenizer = Tokenizer.from_file(os.path.abspath(tokenizer_path)) + + result = self() + result.tokenizer = tokenizer + return result + + def encode(self, prompt): + return self.tokenizer.encode(prompt).ids + + def decode(self, ids): + return self.tokenizer.decode(ids) diff --git a/modules/models.py b/modules/models.py index 40feb8b3..16ce6eb1 100644 --- a/modules/models.py +++ b/modules/models.py @@ -79,11 +79,12 @@ def load_model(model_name): # RMKV model (not on HuggingFace) elif shared.is_RWKV: - from modules.RWKV import RWKVModel + from modules.RWKV import RWKVModel, RWKVTokenizer model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") + tokenizer = RWKVTokenizer.from_pretrained(Path('models')) - return model, None + return model, tokenizer # Custom else: diff --git a/modules/text_generation.py b/modules/text_generation.py index e6207b56..e1ee5294 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -21,21 +21,19 @@ def get_max_prompt_length(tokens): return max_length def encode(prompt, tokens_to_generate=0, add_special_tokens=True): - - # These models do not have explicit tokenizers for now, so - # we return an estimate for the number of tokens if shared.is_RWKV: - return np.zeros((1, len(prompt)//4)) - - input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) - if shared.args.cpu: - return input_ids - elif shared.args.flexgen: - return input_ids.numpy() - elif shared.args.deepspeed: - return input_ids.to(device=local_rank) + input_ids = shared.tokenizer.encode(str(prompt)) + input_ids = np.array(input_ids).reshape(1, len(input_ids)) else: - return input_ids.cuda() + input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) + if shared.args.cpu: + return input_ids + elif shared.args.flexgen: + return input_ids.numpy() + elif shared.args.deepspeed: + return input_ids.to(device=local_rank) + else: + return input_ids.cuda() def decode(output_ids): reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True) From 2de9f122cd7b71fcef91fe27718eb09085be6015 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 09:34:49 -0300 Subject: [PATCH 12/35] Update README.md --- README.md | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f6c03915..8a4031ad 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * Advanced chat features (send images, get audio responses with TTS). * Stream the text output in real time. * Load parameter presets from text files. -* Load large models in 8-bit mode (see [here](https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652) and [here](https://www.reddit.com/r/PygmalionAI/comments/1115gom/running_pygmalion_6b_with_8gb_of_vram/) if you are on Windows). +* Load large models in 8-bit mode (see [here](https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134), [here](https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652) and [here](https://www.reddit.com/r/PygmalionAI/comments/1115gom/running_pygmalion_6b_with_8gb_of_vram/) if you are on Windows). * Split large models across your GPU(s), CPU, and disk. * CPU mode. * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen). From bf56b6c1fbaba0e437b8846b10ddb886d13c0114 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 10:57:45 -0300 Subject: [PATCH 13/35] Load settings.json without the need for --settings settings.json This is for setting UI defaults --- README.md | 2 +- modules/shared.py | 2 +- server.py | 8 +++++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f6c03915..602e6a2c 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Optionally, you can use the following command-line flags: | `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. | | `--rwkv-strategy RWKV_STRATEGY` | The strategy to use while loading RWKV models. Examples: `"cpu fp32"`, `"cuda fp16"`, `"cuda fp16 *30 -> cpu fp32"`. | | `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.| -| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.| +| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.| | `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | | `--listen` | Make the web UI reachable from your local network.| | `--listen-port LISTEN_PORT` | The listening port that the server will use. | diff --git a/modules/shared.py b/modules/shared.py index 29276fd3..e1d3765b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -83,7 +83,7 @@ parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory t parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.') parser.add_argument('--rwkv-strategy', type=str, default=None, help='The strategy to use while loading RWKV models. Examples: "cpu fp32", "cuda fp16", "cuda fp16 *30 -> cpu fp32".') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. 
This improves the text generation performance.') -parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') +parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.') parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') diff --git a/server.py b/server.py index ed46224e..9f584ba3 100644 --- a/server.py +++ b/server.py @@ -22,8 +22,14 @@ if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n') # Loading custom settings +settings_file = None if shared.args.settings is not None and Path(shared.args.settings).exists(): - new_settings = json.loads(open(Path(shared.args.settings), 'r').read()) + settings_file = Path(shared.args.settings) +elif Path('settings.json').exists(): + settings_file = Path('settings.json') +if settings_file is not None: + print(f"Loading settings from {settings_file}...") + new_settings = json.loads(open(settings_file, 'r').read()) for item in new_settings: shared.settings[item] = new_settings[item] From 5bed607b773886b0c5ade204f63f0e5cdb3e502e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:25:48 -0300 Subject: [PATCH 14/35] Increase repetition frequency/penalty for RWKV --- modules/RWKV.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index 2ef86cae..93b11678 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -33,7 +33,7 @@ class RWKVModel: result.pipeline = pipeline return result - def generate(self, context, token_count=20, temperature=1, top_p=1, alpha_frequency=0.25, alpha_presence=0.25, token_ban=[0], token_stop=[], callback=None): + def generate(self, context, token_count=20, temperature=1, top_p=1, alpha_frequency=0.7, alpha_presence=0.7, token_ban=[0], token_stop=[], callback=None): args = PIPELINE_ARGS( temperature = temperature, top_p = top_p, From d88b7836c64e138eb32d0e5797247358b6fe8ae1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:58:30 -0300 Subject: [PATCH 15/35] Fix bug in multigpu setups --- modules/text_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index e1ee5294..caa77df9 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -194,7 +194,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi yield formatted_outputs(reply, shared.model_name) if not shared.args.flexgen: - if output[-1] == n: + if int(output[-1]) == int(n): break input_ids = torch.reshape(output, (1, output.shape[0])) else: From 24c4c2039176d8480613f13d997ebfb95677aeb2 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: 
Mon, 6 Mar 2023 15:23:29 -0300 Subject: [PATCH 16/35] Fix bug in multigpu setups (attempt #2) --- modules/text_generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index caa77df9..e2d85514 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -112,7 +112,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi input_ids = encode(question, max_new_tokens) cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" - n = shared.tokenizer.eos_token_id if eos_token is None else encode(eos_token)[0][-1] + n = torch.tensor(shared.tokenizer.eos_token_id) if eos_token is None else encode(eos_token)[0][-1] if stopping_string is not None: # The stopping_criteria code below was copied from # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py @@ -194,7 +194,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi yield formatted_outputs(reply, shared.model_name) if not shared.args.flexgen: - if int(output[-1]) == int(n): + if output[-1].to("cpu") == n.to("cpu"): break input_ids = torch.reshape(output, (1, output.shape[0])) else: From 09a7c36e1b7dc64bb9e160c6540b774c96f1598a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 15:36:35 -0300 Subject: [PATCH 17/35] Minor improvement while running custom models --- modules/text_generation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/text_generation.py b/modules/text_generation.py index e2d85514..f585c013 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -97,6 +97,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi print(f"Output generated in {(t1-t0):.2f} seconds.") yield formatted_outputs(reply, shared.model_name) else: + yield formatted_outputs(question, shared.model_name) for i in tqdm(range(max_new_tokens//8+1)): clear_torch_cache() reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p) From 20bd645f6a355c39b1907fe213fa44e38a83dc8c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 15:58:18 -0300 Subject: [PATCH 18/35] Fix bug in multigpu setups (attempt 3) --- modules/text_generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index f585c013..0e8aec51 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -113,7 +113,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi input_ids = encode(question, max_new_tokens) cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" - n = torch.tensor(shared.tokenizer.eos_token_id) if eos_token is None else encode(eos_token)[0][-1] + n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1]) if stopping_string is not None: # The stopping_criteria code below was copied from # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py @@ -195,7 +195,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi yield formatted_outputs(reply, shared.model_name) if not shared.args.flexgen: - if output[-1].to("cpu") == n.to("cpu"): + if output[-1] == n: break input_ids = torch.reshape(output, (1, output.shape[0])) else: From 
6904a507c661c9d941d6ff643b16589c6f670905 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 16:29:43 -0300 Subject: [PATCH 19/35] Change some parameters --- modules/RWKV.py | 2 +- presets/Naive.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index 93b11678..9a806a00 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -33,7 +33,7 @@ class RWKVModel: result.pipeline = pipeline return result - def generate(self, context, token_count=20, temperature=1, top_p=1, alpha_frequency=0.7, alpha_presence=0.7, token_ban=[0], token_stop=[], callback=None): + def generate(self, context, token_count=20, temperature=1, top_p=1, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None): args = PIPELINE_ARGS( temperature = temperature, top_p = top_p, diff --git a/presets/Naive.txt b/presets/Naive.txt index c6965983..f3114a50 100644 --- a/presets/Naive.txt +++ b/presets/Naive.txt @@ -1,3 +1,3 @@ do_sample=True -top_p=0.95 -temperature=0.8 +top_p=0.85 +temperature=1 From 91823e1ed1c6c919179a5336cc05f9b0d655315a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 16:48:31 -0300 Subject: [PATCH 20/35] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index efa93d30..f2e8e61a 100644 --- a/README.md +++ b/README.md @@ -82,8 +82,8 @@ Models should be placed under `models/model-name`. For instance, `models/gpt-j-6 * [Pythia](https://huggingface.co/models?search=eleutherai/pythia) * [OPT](https://huggingface.co/models?search=facebook/opt) * [GALACTICA](https://huggingface.co/models?search=facebook/galactica) -* [\*-Erebus](https://huggingface.co/models?search=erebus) -* [Pygmalion](https://huggingface.co/models?search=pygmalion) +* [\*-Erebus](https://huggingface.co/models?search=erebus) (NSFW) +* [Pygmalion](https://huggingface.co/models?search=pygmalion) (NSFW) You can automatically download a model from HF using the script `download-model.py`: From 49ae183ac9662f0062ceb5193d61e6bf1373688f Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 19:28:53 -0300 Subject: [PATCH 21/35] Move new extension to a separate file --- .gitignore | 1 + .../outputs/outputs-will-be-saved-here.txt | 0 extensions/elevenlabs/requirements.txt | 6 + extensions/elevenlabs/script.py | 120 ++++++++++++++++++ extensions/silero_tts/script.py | 102 ++++++--------- 5 files changed, 165 insertions(+), 64 deletions(-) create mode 100644 extensions/elevenlabs/outputs/outputs-will-be-saved-here.txt create mode 100644 extensions/elevenlabs/requirements.txt create mode 100644 extensions/elevenlabs/script.py diff --git a/.gitignore b/.gitignore index 6f4c5ba3..b37a7601 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ cache/* characters/* extensions/silero_tts/outputs/* +extensions/elevenlabs/outputs/* logs/* models/* softprompts/* diff --git a/extensions/elevenlabs/outputs/outputs-will-be-saved-here.txt b/extensions/elevenlabs/outputs/outputs-will-be-saved-here.txt new file mode 100644 index 00000000..e69de29b diff --git a/extensions/elevenlabs/requirements.txt b/extensions/elevenlabs/requirements.txt new file mode 100644 index 00000000..f2f0bff5 --- /dev/null +++ b/extensions/elevenlabs/requirements.txt @@ -0,0 +1,6 @@ +ipython +omegaconf +pydub +PyYAML +torch +torchaudio diff --git 
a/extensions/elevenlabs/script.py b/extensions/elevenlabs/script.py new file mode 100644 index 00000000..e088ec8e --- /dev/null +++ b/extensions/elevenlabs/script.py @@ -0,0 +1,120 @@ +import asyncio +from pathlib import Path + +import gradio as gr +import torch +import io +import json +import os + +import requests + +from elevenlabslib.helpers import * +from elevenlabslib import * + +params = { + 'activate': True, + 'api_key': '12345', + 'selected_voice': 'None', +} +initial_voice = ['None'] +wav_idx = 0 +user = ElevenLabsUser(params['api_key']) +user_info = None + + +"Check if the API is valid and refresh the UI accordingly." +def check_valid_api(): + + global user, user_info, params + + user = ElevenLabsUser(params['api_key']) + user_info = user._get_subscription_data() + print('checking api') + if params['activate'] == False: + return gr.update(value='Disconnected') + elif user_info is None: + print('Incorrect API Key') + return gr.update(value='Disconnected') + else: + print('Got an API Key!') + return gr.update(value='Connected') + +"Once the API is verified, get the available voices and update the dropdown list" +def refresh_voices(): + + global user, user_info + + your_voices = [None] + if user_info is not None: + for voice in user.get_available_voices(): + your_voices.append(voice.initialName) + return gr.Dropdown.update(choices=your_voices) + else: + return + +def remove_surrounded_chars(string): + new_string = "" + in_star = False + for char in string: + if char == '*': + in_star = not in_star + elif not in_star: + new_string += char + return new_string + +def input_modifier(string): + """ + This function is applied to your text inputs before + they are fed into the model. + """ + + return string + +def output_modifier(string): + """ + This function is applied to the model outputs. 
+ """ + global params, wav_idx, user, user_info + + if params['activate'] == False: + return string + elif user_info == None: + return string + + string = remove_surrounded_chars(string) + string = string.replace('"', '') + string = string.replace('“', '') + string = string.replace('\n', ' ') + string = string.strip() + + if string == '': + string = 'empty reply, try regenerating' + + output_file = Path('extensions/elevenlabs_tts/outputs/{}.wav'.format(wav_idx)) + voice = user.get_voices_by_name(params['selected_voice'])[0] + audio_data = voice.generate_audio_bytes(string) + save_bytes_to_path("extensions/elevenlabs_tts/outputs/{}.wav".format(wav_idx), audio_data) + + + string = f'' + wav_idx += 1 + return string + + +def ui(): + # Gradio elements + with gr.Row(): + activate = gr.Checkbox(value=params['activate'], label='Activate TTS') + connection_status = gr.Textbox(value='Disconnected', label='Connection Status') + voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice') + with gr.Row(): + api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key') + connect = gr.Button(value='Connect') + # Event functions to update the parameters in the backend + activate.change(lambda x: params.update({'activate': x}), activate, None) + voice.change(lambda x: params.update({'selected_voice': x}), voice, None) + api_key.change(lambda x: params.update({'api_key': x}), api_key, None) + connect.click(check_valid_api, [], connection_status) + connect.click(refresh_voices, [], voice) + diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index e088ec8e..f697d0e2 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -3,55 +3,26 @@ from pathlib import Path import gradio as gr import torch -import io -import json -import os -import requests - -from elevenlabslib.helpers import * -from elevenlabslib import * +torch._C._jit_set_profiling_mode(False) params = { 'activate': True, - 'api_key': '12345', - 'selected_voice': 'None', + 'speaker': 'en_56', + 'language': 'en', + 'model_id': 'v3_en', + 'sample_rate': 48000, + 'device': 'cpu', } -initial_voice = ['None'] +current_params = params.copy() +voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] wav_idx = 0 -user = ElevenLabsUser(params['api_key']) -user_info = None - -"Check if the API is valid and refresh the UI accordingly." 
-def check_valid_api(): - - global user, user_info, params - - user = ElevenLabsUser(params['api_key']) - user_info = user._get_subscription_data() - print('checking api') - if params['activate'] == False: - return gr.update(value='Disconnected') - elif user_info is None: - print('Incorrect API Key') - return gr.update(value='Disconnected') - else: - print('Got an API Key!') - return gr.update(value='Connected') - -"Once the API is verified, get the available voices and update the dropdown list" -def refresh_voices(): - - global user, user_info - - your_voices = [None] - if user_info is not None: - for voice in user.get_available_voices(): - your_voices.append(voice.initialName) - return gr.Dropdown.update(choices=your_voices) - else: - return +def load_model(): + model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) + model.to(params['device']) + return model +model = load_model() def remove_surrounded_chars(string): new_string = "" @@ -75,12 +46,17 @@ def output_modifier(string): """ This function is applied to the model outputs. """ - global params, wav_idx, user, user_info - + + global wav_idx, model, current_params + + for i in params: + if params[i] != current_params[i]: + model = load_model() + current_params = params.copy() + break + if params['activate'] == False: return string - elif user_info == None: - return string string = remove_surrounded_chars(string) string = string.replace('"', '') @@ -90,31 +66,29 @@ def output_modifier(string): if string == '': string = 'empty reply, try regenerating' - - output_file = Path('extensions/elevenlabs_tts/outputs/{}.wav'.format(wav_idx)) - voice = user.get_voices_by_name(params['selected_voice'])[0] - audio_data = voice.generate_audio_bytes(string) - save_bytes_to_path("extensions/elevenlabs_tts/outputs/{}.wav".format(wav_idx), audio_data) + output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav') + audio = model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) string = f'' wav_idx += 1 + return string +def bot_prefix_modifier(string): + """ + This function is only applied in chat mode. It modifies + the prefix text for the Bot and can be used to bias its + behavior. 
+ """ + + return string def ui(): # Gradio elements - with gr.Row(): - activate = gr.Checkbox(value=params['activate'], label='Activate TTS') - connection_status = gr.Textbox(value='Disconnected', label='Connection Status') - voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice') - with gr.Row(): - api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key') - connect = gr.Button(value='Connect') - # Event functions to update the parameters in the backend - activate.change(lambda x: params.update({'activate': x}), activate, None) - voice.change(lambda x: params.update({'selected_voice': x}), voice, None) - api_key.change(lambda x: params.update({'api_key': x}), api_key, None) - connect.click(check_valid_api, [], connection_status) - connect.click(refresh_voices, [], voice) + activate = gr.Checkbox(value=params['activate'], label='Activate TTS') + voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') + # Event functions to update the parameters in the backend + activate.change(lambda x: params.update({"activate": x}), activate, None) + voice.change(lambda x: params.update({"speaker": x}), voice, None) From 944fdc03b231492959d633088b8ca0b7ac6b3200 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 19:38:36 -0300 Subject: [PATCH 22/35] Rename the folder --- .gitignore | 2 +- extensions/elevenlabs/requirements.txt | 6 ------ .../outputs/outputs-will-be-saved-here.txt | 0 extensions/elevenlabs_tts/requirements.txt | 3 +++ extensions/{elevenlabs => elevenlabs_tts}/script.py | 10 ++++------ 5 files changed, 8 insertions(+), 13 deletions(-) delete mode 100644 extensions/elevenlabs/requirements.txt rename extensions/{elevenlabs => elevenlabs_tts}/outputs/outputs-will-be-saved-here.txt (100%) create mode 100644 extensions/elevenlabs_tts/requirements.txt rename extensions/{elevenlabs => elevenlabs_tts}/script.py (99%) diff --git a/.gitignore b/.gitignore index b37a7601..1b7f0fb8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ cache/* characters/* extensions/silero_tts/outputs/* -extensions/elevenlabs/outputs/* +extensions/elevenlabs_tts/outputs/* logs/* models/* softprompts/* diff --git a/extensions/elevenlabs/requirements.txt b/extensions/elevenlabs/requirements.txt deleted file mode 100644 index f2f0bff5..00000000 --- a/extensions/elevenlabs/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -ipython -omegaconf -pydub -PyYAML -torch -torchaudio diff --git a/extensions/elevenlabs/outputs/outputs-will-be-saved-here.txt b/extensions/elevenlabs_tts/outputs/outputs-will-be-saved-here.txt similarity index 100% rename from extensions/elevenlabs/outputs/outputs-will-be-saved-here.txt rename to extensions/elevenlabs_tts/outputs/outputs-will-be-saved-here.txt diff --git a/extensions/elevenlabs_tts/requirements.txt b/extensions/elevenlabs_tts/requirements.txt new file mode 100644 index 00000000..8ec07a8a --- /dev/null +++ b/extensions/elevenlabs_tts/requirements.txt @@ -0,0 +1,3 @@ +elevenlabslib +soundfile +sounddevice diff --git a/extensions/elevenlabs/script.py b/extensions/elevenlabs_tts/script.py similarity index 99% rename from extensions/elevenlabs/script.py rename to extensions/elevenlabs_tts/script.py index e088ec8e..ae4fe019 100644 --- a/extensions/elevenlabs/script.py +++ b/extensions/elevenlabs_tts/script.py @@ -1,16 +1,14 @@ import asyncio -from pathlib import Path - -import gradio as gr -import torch import io import json import os +from pathlib 
import Path +import gradio as gr import requests - -from elevenlabslib.helpers import * +import torch from elevenlabslib import * +from elevenlabslib.helpers import * params = { 'activate': True, From eebec650756ea80ec601a93a6beda3d9793136a7 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 19:46:46 -0300 Subject: [PATCH 23/35] Improve readability --- extensions/elevenlabs_tts/script.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py index ae4fe019..24741837 100644 --- a/extensions/elevenlabs_tts/script.py +++ b/extensions/elevenlabs_tts/script.py @@ -15,13 +15,14 @@ params = { 'api_key': '12345', 'selected_voice': 'None', } + initial_voice = ['None'] wav_idx = 0 user = ElevenLabsUser(params['api_key']) user_info = None -"Check if the API is valid and refresh the UI accordingly." +# Check if the API is valid and refresh the UI accordingly. def check_valid_api(): global user, user_info, params @@ -38,7 +39,7 @@ def check_valid_api(): print('Got an API Key!') return gr.update(value='Connected') -"Once the API is verified, get the available voices and update the dropdown list" +# Once the API is verified, get the available voices and update the dropdown list def refresh_voices(): global user, user_info @@ -73,6 +74,7 @@ def output_modifier(string): """ This function is applied to the model outputs. """ + global params, wav_idx, user, user_info if params['activate'] == False: @@ -89,18 +91,17 @@ def output_modifier(string): if string == '': string = 'empty reply, try regenerating' - output_file = Path('extensions/elevenlabs_tts/outputs/{}.wav'.format(wav_idx)) + output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx)) voice = user.get_voices_by_name(params['selected_voice'])[0] audio_data = voice.generate_audio_bytes(string) - save_bytes_to_path("extensions/elevenlabs_tts/outputs/{}.wav".format(wav_idx), audio_data) - + save_bytes_to_path(Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'), audio_data) string = f'' wav_idx += 1 return string - def ui(): + # Gradio elements with gr.Row(): activate = gr.Checkbox(value=params['activate'], label='Activate TTS') @@ -109,10 +110,10 @@ def ui(): with gr.Row(): api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key') connect = gr.Button(value='Connect') + # Event functions to update the parameters in the backend activate.change(lambda x: params.update({'activate': x}), activate, None) voice.change(lambda x: params.update({'selected_voice': x}), voice, None) api_key.change(lambda x: params.update({'api_key': x}), api_key, None) connect.click(check_valid_api, [], connection_status) connect.click(refresh_voices, [], voice) - From 8b882c132a42de35ad2f62536a035cc7ce017e34 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 19:52:26 -0300 Subject: [PATCH 24/35] tabs -> spaces --- api-example-stream.py | 130 +++++++++++++++++++++--------------------- 1 file changed, 65 insertions(+), 65 deletions(-) diff --git a/api-example-stream.py b/api-example-stream.py index 0d93b4b6..b7846ab4 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -5,77 +5,77 @@ import json import asyncio def random_hash(): - letters = string.ascii_lowercase + string.digits - return ''.join(random.choice(letters) for i in range(9)) + letters = string.ascii_lowercase + string.digits + return 
''.join(random.choice(letters) for i in range(9)) async def run(context): - server = "127.0.0.1" - params = { - 'max_new_tokens': 200, - 'do_sample': True, - 'temperature': 0.5, - 'top_p': 0.9, - 'typical_p': 1, - 'repetition_penalty': 1.05, - 'top_k': 0, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - } - session = random_hash() + server = "127.0.0.1" + params = { + 'max_new_tokens': 200, + 'do_sample': True, + 'temperature': 0.5, + 'top_p': 0.9, + 'typical_p': 1, + 'repetition_penalty': 1.05, + 'top_k': 0, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'length_penalty': 1, + 'early_stopping': False, + } + session = random_hash() - async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: - while content := json.loads(await websocket.recv()): - #Python3.10 syntax, replace with if elif on older - match content["msg"]: - case "send_hash": - await websocket.send(json.dumps({ - "session_hash": session, - "fn_index": 7 - })) - case "estimation": - pass - case "send_data": - await websocket.send(json.dumps({ - "session_hash": session, - "fn_index": 7, - "data": [ - context, - params['max_new_tokens'], - params['do_sample'], - params['temperature'], - params['top_p'], - params['typical_p'], - params['repetition_penalty'], - params['top_k'], - params['min_length'], - params['no_repeat_ngram_size'], - params['num_beams'], - params['penalty_alpha'], - params['length_penalty'], - params['early_stopping'], - ] - })) - case "process_starts": - pass - case "process_generating" | "process_completed": - yield content["output"]["data"][0] - # You can search for your desired end indicator and - # stop generation by closing the websocket here - if (content["msg"] == "process_completed"): - break + async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: + while content := json.loads(await websocket.recv()): + #Python3.10 syntax, replace with if elif on older + match content["msg"]: + case "send_hash": + await websocket.send(json.dumps({ + "session_hash": session, + "fn_index": 7 + })) + case "estimation": + pass + case "send_data": + await websocket.send(json.dumps({ + "session_hash": session, + "fn_index": 7, + "data": [ + context, + params['max_new_tokens'], + params['do_sample'], + params['temperature'], + params['top_p'], + params['typical_p'], + params['repetition_penalty'], + params['top_k'], + params['min_length'], + params['no_repeat_ngram_size'], + params['num_beams'], + params['penalty_alpha'], + params['length_penalty'], + params['early_stopping'], + ] + })) + case "process_starts": + pass + case "process_generating" | "process_completed": + yield content["output"]["data"][0] + # You can search for your desired end indicator and + # stop generation by closing the websocket here + if (content["msg"] == "process_completed"): + break prompt = "What I would like to say is the following: " async def get_result(): - async for response in run(prompt): - # Print intermediate steps - print(response) + async for response in run(prompt): + # Print intermediate steps + print(response) - # Print final result - print(response) + # Print final result + print(response) -asyncio.run(get_result()) \ No newline at end of file +asyncio.run(get_result()) From 153dfeb4dde562ef0bad6743e832e76f28dc9200 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 20:12:54 -0300 Subject: [PATCH 25/35] 
Add --rwkv-cuda-on parameter, bump rwkv version --- modules/RWKV.py | 2 +- modules/shared.py | 3 ++- requirements.txt | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index 9a806a00..acc97044 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -9,7 +9,7 @@ import modules.shared as shared np.set_printoptions(precision=4, suppress=True, linewidth=200) os.environ['RWKV_JIT_ON'] = '1' -os.environ["RWKV_CUDA_ON"] = '0' # '1' : use CUDA kernel for seq mode (much faster) +os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster) from rwkv.model import RWKV from rwkv.utils import PIPELINE, PIPELINE_ARGS diff --git a/modules/shared.py b/modules/shared.py index e1d3765b..b609045c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -81,7 +81,8 @@ parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, defaul parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.') parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.') -parser.add_argument('--rwkv-strategy', type=str, default=None, help='The strategy to use while loading RWKV models. Examples: "cpu fp32", "cuda fp16", "cuda fp16 *30 -> cpu fp32".') +parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".') +parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.') parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') diff --git a/requirements.txt b/requirements.txt index 2051dc0b..3a2ac25d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ bitsandbytes==0.37.0 flexgen==0.1.7 gradio==3.18.0 numpy -rwkv==0.0.7 +rwkv==0.0.8 safetensors==0.2.8 sentencepiece git+https://github.com/oobabooga/transformers@llama_push From 18ccfcd7fe0607ff6d3fab7b23d4a167a1c8e6ea Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 20:15:55 -0300 Subject: [PATCH 26/35] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index f2e8e61a..7a3dc065 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen). * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed). * [Get responses via API](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py). 
+* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model). * Supports softprompts. * [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions). * [Works on Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab). From d0e87805557ac3c482bffc89ad44fec1d2b634a9 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 20:17:59 -0300 Subject: [PATCH 27/35] Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7a3dc065..eef4fc38 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,8 @@ Optionally, you can use the following command-line flags: | `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. | | `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. | | `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. | -| `--rwkv-strategy RWKV_STRATEGY` | The strategy to use while loading RWKV models. Examples: `"cpu fp32"`, `"cuda fp16"`, `"cuda fp16 *30 -> cpu fp32"`. | +| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". | +| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | | `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.| | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.| | `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | From 8f4a197c055381d744cd4019e498faf985df85c6 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 20:34:36 -0300 Subject: [PATCH 28/35] Add credits --- api-example-stream.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/api-example-stream.py b/api-example-stream.py index b7846ab4..78ff0bd7 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -1,3 +1,10 @@ +''' + +Contributed by SagsMug. Thank you SagsMug. +https://github.com/oobabooga/text-generation-webui/pull/175 + +''' + import string import random import websockets From b4bfd87319838ca0e7b450861830b60f71588720 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Mar 2023 20:55:01 -0300 Subject: [PATCH 29/35] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index eef4fc38..9efacb7c 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * CPU mode. * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen). * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed). -* [Get responses via API](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py). 
+* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming. * [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model). * Supports softprompts. * [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions). From 827ae51f7240e3fbdb45b43a8bc5880a16565478 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 7 Mar 2023 00:23:36 -0300 Subject: [PATCH 30/35] Sort the imports --- api-example-stream.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/api-example-stream.py b/api-example-stream.py index 78ff0bd7..a5ed4202 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -5,11 +5,13 @@ https://github.com/oobabooga/text-generation-webui/pull/175 ''' -import string -import random -import websockets -import json import asyncio +import json +import random +import string + +import websockets + def random_hash(): letters = string.ascii_lowercase + string.digits From 8660227e1b8a926bd68b552f15190f94af785036 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 7 Mar 2023 17:24:28 -0300 Subject: [PATCH 31/35] Add top_k to RWKV --- modules/RWKV.py | 3 ++- modules/text_generation.py | 4 ++-- requirements.txt | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index acc97044..739a7e73 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -33,10 +33,11 @@ class RWKVModel: result.pipeline = pipeline return result - def generate(self, context, token_count=20, temperature=1, top_p=1, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None): + def generate(self, context, token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None): args = PIPELINE_ARGS( temperature = temperature, top_p = top_p, + top_k = top_k, alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3) alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3) token_ban = token_ban, # ban the generation of some tokens diff --git a/modules/text_generation.py b/modules/text_generation.py index 0e8aec51..0807a41e 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -92,7 +92,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # separately and terminate the function call earlier if shared.is_RWKV: if shared.args.no_stream: - reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p) + reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) t1 = time.time() print(f"Output generated in {(t1-t0):.2f} seconds.") yield formatted_outputs(reply, shared.model_name) @@ -100,7 +100,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi yield formatted_outputs(question, shared.model_name) for i in tqdm(range(max_new_tokens//8+1)): clear_torch_cache() - reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p) + reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p, top_k=top_k) yield formatted_outputs(reply, shared.model_name) question = reply return diff --git 
a/requirements.txt b/requirements.txt index 3a2ac25d..47c56a45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ bitsandbytes==0.37.0 flexgen==0.1.7 gradio==3.18.0 numpy -rwkv==0.0.8 +rwkv==0.1.0 safetensors==0.2.8 sentencepiece git+https://github.com/oobabooga/transformers@llama_push From 19a34941ed9daeeb93ab951645b61957e4df5376 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 7 Mar 2023 18:17:56 -0300 Subject: [PATCH 32/35] Add proper streaming to RWKV --- modules/RWKV.py | 46 +++++++++++++++++++++++++++++++++++++- modules/text_generation.py | 14 ++++++------ 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index 739a7e73..b226a195 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -1,5 +1,7 @@ import os from pathlib import Path +from queue import Queue +from threading import Thread import numpy as np from tokenizers import Tokenizer @@ -33,7 +35,7 @@ class RWKVModel: result.pipeline = pipeline return result - def generate(self, context, token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None): + def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None): args = PIPELINE_ARGS( temperature = temperature, top_p = top_p, @@ -46,6 +48,13 @@ class RWKVModel: return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) + def generate_with_streaming(self, **kwargs): + iterable = Iteratorize(self.generate, kwargs, callback=None) + reply = kwargs['context'] + for token in iterable: + reply += token + yield reply + class RWKVTokenizer: def __init__(self): pass @@ -64,3 +73,38 @@ class RWKVTokenizer: def decode(self, ids): return self.tokenizer.decode(ids) + +class Iteratorize: + + """ + Transforms a function that takes a callback + into a lazy iterator (generator). 
+ """ + + def __init__(self, func, kwargs={}, callback=None): + self.mfunc=func + self.c_callback=callback + self.q = Queue(maxsize=1) + self.sentinel = object() + self.kwargs = kwargs + + def _callback(val): + self.q.put(val) + + def gentask(): + ret = self.mfunc(callback=_callback, **self.kwargs) + self.q.put(self.sentinel) + if self.c_callback: + self.c_callback(ret) + + Thread(target=gentask).start() + + def __iter__(self): + return self + + def __next__(self): + obj = self.q.get(True,None) + if obj is self.sentinel: + raise StopIteration + else: + return obj diff --git a/modules/text_generation.py b/modules/text_generation.py index 0807a41e..9adc2fdd 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -92,17 +92,17 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # separately and terminate the function call earlier if shared.is_RWKV: if shared.args.no_stream: - reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) - t1 = time.time() - print(f"Output generated in {(t1-t0):.2f} seconds.") + reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) yield formatted_outputs(reply, shared.model_name) else: yield formatted_outputs(question, shared.model_name) - for i in tqdm(range(max_new_tokens//8+1)): - clear_torch_cache() - reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p, top_k=top_k) + # RWKV has proper streaming, which is very nice. + # No need to generate 8 tokens at a time. + for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k): yield formatted_outputs(reply, shared.model_name) - question = reply + + t1 = time.time() + print(f"Output generated in {(t1-t0):.2f} seconds.") return original_question = question From 44e6d821859f31e6fa504899bf2bf42cbb28c189 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 7 Mar 2023 22:56:15 -0300 Subject: [PATCH 33/35] Remove unused imports --- extensions/elevenlabs_tts/script.py | 6 ------ extensions/silero_tts/script.py | 3 +-- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py index 24741837..90d61efc 100644 --- a/extensions/elevenlabs_tts/script.py +++ b/extensions/elevenlabs_tts/script.py @@ -1,12 +1,6 @@ -import asyncio -import io -import json -import os from pathlib import Path import gradio as gr -import requests -import torch from elevenlabslib import * from elevenlabslib.helpers import * diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index f697d0e2..050392d6 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -1,4 +1,3 @@ -import asyncio from pathlib import Path import gradio as gr @@ -68,7 +67,7 @@ def output_modifier(string): string = 'empty reply, try regenerating' output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav') - audio = model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) + model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) string = f'' wav_idx += 1 From 8e89bc596b2fee8a7e4bcc49d25f900515beafb5 Mon Sep 17 00:00:00 2001 From: oobabooga 
<112222186+oobabooga@users.noreply.github.com> Date: Tue, 7 Mar 2023 23:15:46 -0300 Subject: [PATCH 34/35] Fix encode() for RWKV --- modules/text_generation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/text_generation.py b/modules/text_generation.py index 9adc2fdd..4af53273 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -24,6 +24,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True): if shared.is_RWKV: input_ids = shared.tokenizer.encode(str(prompt)) input_ids = np.array(input_ids).reshape(1, len(input_ids)) + return input_ids else: input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) if shared.args.cpu: From c09f416adbd4a4df6c36f49c4ce54b749ad5d1e7 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 7 Mar 2023 23:17:13 -0300 Subject: [PATCH 35/35] Change the Naive preset (again) --- presets/Naive.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/presets/Naive.txt b/presets/Naive.txt index f3114a50..aa8c0582 100644 --- a/presets/Naive.txt +++ b/presets/Naive.txt @@ -1,3 +1,4 @@ do_sample=True +temperature=0.7 top_p=0.85 -temperature=1 +top_k=50
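
A note on the RWKV streaming interface added in PATCH 32: generate_with_streaming wraps the callback-based pipeline.generate in the Iteratorize helper (a bounded queue fed by a worker thread) so that callers can consume tokens as an ordinary Python generator, which is exactly how modules/text_generation.py uses it above. The sketch below is only an illustration under stated assumptions, not part of the patch series: it presumes a model has already been loaded into shared.model as an RWKVModel (for example via load_model() in modules/models.py, which is not shown in these patches), and the prompt and sampling values are arbitrary placeholders rather than recommended settings.

    # Illustrative sketch, not part of the patches above.
    # Assumption: shared.model already holds a loaded RWKVModel instance.
    import modules.shared as shared

    prompt = "What I would like to say is the following: "

    # generate_with_streaming() yields the cumulative text (prompt included)
    # each time a new token arrives, mirroring modules/text_generation.py.
    for reply in shared.model.generate_with_streaming(context=prompt,
                                                      token_count=200,
                                                      temperature=0.5,
                                                      top_p=0.9,
                                                      top_k=40):
        print(reply)  # intermediate, cumulative output

    print(reply)  # final result once the generator is exhausted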