diff --git a/.gitignore b/.gitignore
index 6f4c5ba3..1b7f0fb8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
cache/*
characters/*
extensions/silero_tts/outputs/*
+extensions/elevenlabs_tts/outputs/*
logs/*
models/*
softprompts/*
diff --git a/README.md b/README.md
index f6c03915..9efacb7c 100644
--- a/README.md
+++ b/README.md
@@ -21,12 +21,13 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* Advanced chat features (send images, get audio responses with TTS).
* Stream the text output in real time.
* Load parameter presets from text files.
-* Load large models in 8-bit mode (see [here](https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652) and [here](https://www.reddit.com/r/PygmalionAI/comments/1115gom/running_pygmalion_6b_with_8gb_of_vram/) if you are on Windows).
+* Load large models in 8-bit mode (see [here](https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134), [here](https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652) and [here](https://www.reddit.com/r/PygmalionAI/comments/1115gom/running_pygmalion_6b_with_8gb_of_vram/) if you are on Windows).
* Split large models across your GPU(s), CPU, and disk.
* CPU mode.
* [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
* [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
-* [Get responses via API](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py).
+* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
+* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
* Supports softprompts.
* [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
* [Works on Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab).
@@ -82,8 +83,8 @@ Models should be placed under `models/model-name`. For instance, `models/gpt-j-6
* [Pythia](https://huggingface.co/models?search=eleutherai/pythia)
* [OPT](https://huggingface.co/models?search=facebook/opt)
* [GALACTICA](https://huggingface.co/models?search=facebook/galactica)
-* [\*-Erebus](https://huggingface.co/models?search=erebus)
-* [Pygmalion](https://huggingface.co/models?search=pygmalion)
+* [\*-Erebus](https://huggingface.co/models?search=erebus) (NSFW)
+* [Pygmalion](https://huggingface.co/models?search=pygmalion) (NSFW)
You can automatically download a model from HF using the script `download-model.py`:
@@ -149,9 +150,10 @@ Optionally, you can use the following command-line flags:
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
-| `--rwkv-strategy RWKV_STRATEGY` | The strategy to use while loading RWKV models. Examples: `"cpu fp32"`, `"cuda fp16"`, `"cuda fp16 *30 -> cpu fp32"`. |
+| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
+| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
| `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.|
-| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
+| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
| `--listen` | Make the web UI reachable from your local network.|
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
diff --git a/api-example-stream.py b/api-example-stream.py
new file mode 100644
index 00000000..a5ed4202
--- /dev/null
+++ b/api-example-stream.py
@@ -0,0 +1,90 @@
+'''
+
+Contributed by SagsMug. Thank you SagsMug.
+https://github.com/oobabooga/text-generation-webui/pull/175
+
+'''
+
+import asyncio
+import json
+import random
+import string
+
+import websockets
+
+
+def random_hash():
+ letters = string.ascii_lowercase + string.digits
+ return ''.join(random.choice(letters) for i in range(9))
+
+async def run(context):
+ server = "127.0.0.1"
+ params = {
+ 'max_new_tokens': 200,
+ 'do_sample': True,
+ 'temperature': 0.5,
+ 'top_p': 0.9,
+ 'typical_p': 1,
+ 'repetition_penalty': 1.05,
+ 'top_k': 0,
+ 'min_length': 0,
+ 'no_repeat_ngram_size': 0,
+ 'num_beams': 1,
+ 'penalty_alpha': 0,
+ 'length_penalty': 1,
+ 'early_stopping': False,
+ }
+ session = random_hash()
+
+ async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
+ while content := json.loads(await websocket.recv()):
+ #Python3.10 syntax, replace with if elif on older
+ match content["msg"]:
+ case "send_hash":
+ await websocket.send(json.dumps({
+ "session_hash": session,
+ "fn_index": 7
+ }))
+ case "estimation":
+ pass
+ case "send_data":
+ await websocket.send(json.dumps({
+ "session_hash": session,
+ "fn_index": 7,
+ "data": [
+ context,
+ params['max_new_tokens'],
+ params['do_sample'],
+ params['temperature'],
+ params['top_p'],
+ params['typical_p'],
+ params['repetition_penalty'],
+ params['top_k'],
+ params['min_length'],
+ params['no_repeat_ngram_size'],
+ params['num_beams'],
+ params['penalty_alpha'],
+ params['length_penalty'],
+ params['early_stopping'],
+ ]
+ }))
+ case "process_starts":
+ pass
+ case "process_generating" | "process_completed":
+ yield content["output"]["data"][0]
+ # You can search for your desired end indicator and
+ # stop generation by closing the websocket here
+ if (content["msg"] == "process_completed"):
+ break
+
+prompt = "What I would like to say is the following: "
+
+async def get_result():
+ async for response in run(prompt):
+ # Print intermediate steps
+ print(response)
+
+ # Print final result
+ print(response)
+
+asyncio.run(get_result())
diff --git a/extensions/elevenlabs_tts/outputs/outputs-will-be-saved-here.txt b/extensions/elevenlabs_tts/outputs/outputs-will-be-saved-here.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions/elevenlabs_tts/requirements.txt b/extensions/elevenlabs_tts/requirements.txt
new file mode 100644
index 00000000..8ec07a8a
--- /dev/null
+++ b/extensions/elevenlabs_tts/requirements.txt
@@ -0,0 +1,3 @@
+elevenlabslib
+soundfile
+sounddevice
diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py
new file mode 100644
index 00000000..90d61efc
--- /dev/null
+++ b/extensions/elevenlabs_tts/script.py
@@ -0,0 +1,113 @@
+from pathlib import Path
+
+import gradio as gr
+from elevenlabslib import *
+from elevenlabslib.helpers import *
+
+params = {
+ 'activate': True,
+ 'api_key': '12345',
+ 'selected_voice': 'None',
+}
+
+initial_voice = ['None']
+wav_idx = 0
+user = ElevenLabsUser(params['api_key'])
+user_info = None
+
+
+# Check if the API is valid and refresh the UI accordingly.
+def check_valid_api():
+
+ global user, user_info, params
+
+ user = ElevenLabsUser(params['api_key'])
+ user_info = user._get_subscription_data()
+ print('checking api')
+ if params['activate'] == False:
+ return gr.update(value='Disconnected')
+ elif user_info is None:
+ print('Incorrect API Key')
+ return gr.update(value='Disconnected')
+ else:
+ print('Got an API Key!')
+ return gr.update(value='Connected')
+
+# Once the API is verified, get the available voices and update the dropdown list
+def refresh_voices():
+
+ global user, user_info
+
+ your_voices = [None]
+ if user_info is not None:
+ for voice in user.get_available_voices():
+ your_voices.append(voice.initialName)
+ return gr.Dropdown.update(choices=your_voices)
+ else:
+ return
+
+def remove_surrounded_chars(string):
+ new_string = ""
+ in_star = False
+ for char in string:
+ if char == '*':
+ in_star = not in_star
+ elif not in_star:
+ new_string += char
+ return new_string
+
+def input_modifier(string):
+ """
+ This function is applied to your text inputs before
+ they are fed into the model.
+ """
+
+ return string
+
+def output_modifier(string):
+ """
+ This function is applied to the model outputs.
+ """
+
+ global params, wav_idx, user, user_info
+
+ if params['activate'] == False:
+ return string
+ elif user_info == None:
+ return string
+
+ string = remove_surrounded_chars(string)
+ string = string.replace('"', '')
+ string = string.replace('“', '')
+ string = string.replace('\n', ' ')
+ string = string.strip()
+
+ if string == '':
+ string = 'empty reply, try regenerating'
+
+ output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx))
+ voice = user.get_voices_by_name(params['selected_voice'])[0]
+ audio_data = voice.generate_audio_bytes(string)
+ save_bytes_to_path(Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'), audio_data)
+
+ string = f''
+ wav_idx += 1
+ return string
+
+def ui():
+
+ # Gradio elements
+ with gr.Row():
+ activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
+ connection_status = gr.Textbox(value='Disconnected', label='Connection Status')
+ voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice')
+ with gr.Row():
+ api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
+ connect = gr.Button(value='Connect')
+
+ # Event functions to update the parameters in the backend
+ activate.change(lambda x: params.update({'activate': x}), activate, None)
+ voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
+ api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
+ connect.click(check_valid_api, [], connection_status)
+ connect.click(refresh_voices, [], voice)
diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py
index 03319dbf..53bd554c 100644
--- a/extensions/silero_tts/script.py
+++ b/extensions/silero_tts/script.py
@@ -1,4 +1,3 @@
-import asyncio
from pathlib import Path
import gradio as gr
@@ -94,7 +93,7 @@ def output_modifier(string):
string =''+prosody+xmlesc(string)+''
output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
- audio = model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
+ model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
string = f''
#reset if too many wavs. set max to -1 for unlimited.
diff --git a/modules/LLaMA.py b/modules/LLaMA.py
deleted file mode 100644
index 3781ccf5..00000000
--- a/modules/LLaMA.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the GNU General Public License version 3.
-
-import json
-import os
-import sys
-import time
-from pathlib import Path
-from typing import Tuple
-
-import fire
-import torch
-from fairscale.nn.model_parallel.initialize import initialize_model_parallel
-from llama import LLaMA, ModelArgs, Tokenizer, Transformer
-
-os.environ['RANK'] = '0'
-os.environ['WORLD_SIZE'] = '1'
-os.environ['MP'] = '1'
-os.environ['MASTER_ADDR'] = '127.0.0.1'
-os.environ['MASTER_PORT'] = '2223'
-
-def setup_model_parallel() -> Tuple[int, int]:
- local_rank = int(os.environ.get("LOCAL_RANK", -1))
- world_size = int(os.environ.get("WORLD_SIZE", -1))
-
- torch.distributed.init_process_group("gloo")
- initialize_model_parallel(world_size)
- torch.cuda.set_device(local_rank)
-
- # seed must be the same in all processes
- torch.manual_seed(1)
- return local_rank, world_size
-
-def load(
- ckpt_dir: str,
- tokenizer_path: str,
- local_rank: int,
- world_size: int,
- max_seq_len: int,
- max_batch_size: int,
-) -> LLaMA:
- start_time = time.time()
- checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
- assert world_size == len(
- checkpoints
- ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
- ckpt_path = checkpoints[local_rank]
- print("Loading")
- checkpoint = torch.load(ckpt_path, map_location="cpu")
- with open(Path(ckpt_dir) / "params.json", "r") as f:
- params = json.loads(f.read())
-
- model_args: ModelArgs = ModelArgs(
- max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
- )
- tokenizer = Tokenizer(model_path=tokenizer_path)
- model_args.vocab_size = tokenizer.n_words
- torch.set_default_tensor_type(torch.cuda.HalfTensor)
- model = Transformer(model_args)
- torch.set_default_tensor_type(torch.FloatTensor)
- model.load_state_dict(checkpoint, strict=False)
-
- generator = LLaMA(model, tokenizer)
- print(f"Loaded in {time.time() - start_time:.2f} seconds")
- return generator
-
-
-class LLaMAModel:
- def __init__(self):
- pass
-
- @classmethod
- def from_pretrained(self, path, max_seq_len=2048, max_batch_size=1):
- tokenizer_path = path / "tokenizer.model"
- path = os.path.abspath(path)
- tokenizer_path = os.path.abspath(tokenizer_path)
-
- local_rank, world_size = setup_model_parallel()
- if local_rank > 0:
- sys.stdout = open(os.devnull, "w")
-
- generator = load(
- path, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
- )
-
- result = self()
- result.pipeline = generator
- return result
-
- def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):
-
- results = self.pipeline.generate(
- [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
- )
-
- return results[0]
diff --git a/modules/RWKV.py b/modules/RWKV.py
index 46d8ff5f..b226a195 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -1,14 +1,17 @@
import os
from pathlib import Path
+from queue import Queue
+from threading import Thread
import numpy as np
+from tokenizers import Tokenizer
import modules.shared as shared
np.set_printoptions(precision=4, suppress=True, linewidth=200)
os.environ['RWKV_JIT_ON'] = '1'
-os.environ["RWKV_CUDA_ON"] = '0' # '1' : use CUDA kernel for seq mode (much faster)
+os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
@@ -32,10 +35,11 @@ class RWKVModel:
result.pipeline = pipeline
return result
- def generate(self, context, token_count=20, temperature=1, top_p=1, alpha_frequency=0.25, alpha_presence=0.25, token_ban=[0], token_stop=[], callback=None):
+ def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
args = PIPELINE_ARGS(
temperature = temperature,
top_p = top_p,
+ top_k = top_k,
alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3)
alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3)
token_ban = token_ban, # ban the generation of some tokens
@@ -43,3 +47,64 @@ class RWKVModel:
)
return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
+
+ def generate_with_streaming(self, **kwargs):
+ iterable = Iteratorize(self.generate, kwargs, callback=None)
+ reply = kwargs['context']
+ for token in iterable:
+ reply += token
+ yield reply
+
+class RWKVTokenizer:
+ def __init__(self):
+ pass
+
+ @classmethod
+ def from_pretrained(self, path):
+ tokenizer_path = path / "20B_tokenizer.json"
+ tokenizer = Tokenizer.from_file(os.path.abspath(tokenizer_path))
+
+ result = self()
+ result.tokenizer = tokenizer
+ return result
+
+ def encode(self, prompt):
+ return self.tokenizer.encode(prompt).ids
+
+ def decode(self, ids):
+ return self.tokenizer.decode(ids)
+
+class Iteratorize:
+
+ """
+ Transforms a function that takes a callback
+ into a lazy iterator (generator).
+ """
+
+ def __init__(self, func, kwargs={}, callback=None):
+ self.mfunc=func
+ self.c_callback=callback
+ self.q = Queue(maxsize=1)
+ self.sentinel = object()
+ self.kwargs = kwargs
+
+ def _callback(val):
+ self.q.put(val)
+
+ def gentask():
+ ret = self.mfunc(callback=_callback, **self.kwargs)
+ self.q.put(self.sentinel)
+ if self.c_callback:
+ self.c_callback(ret)
+
+ Thread(target=gentask).start()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ obj = self.q.get(True,None)
+ if obj is self.sentinel:
+ raise StopIteration
+ else:
+ return obj
diff --git a/modules/chat.py b/modules/chat.py
index 3b4cbba3..f40f8299 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -51,23 +51,29 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
prompt = ''.join(rows)
return prompt
-def extract_message_from_reply(question, reply, current, other, check, extensions=False):
+def extract_message_from_reply(question, reply, name1, name2, check, impersonate=False):
next_character_found = False
substring_found = False
- previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", question)]
- idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", reply)]
- idx = idx[len(previous_idx)-1]
+ asker = name1 if not impersonate else name2
+ replier = name2 if not impersonate else name1
- if extensions:
- reply = reply[idx + 1 + len(apply_extensions(f"{current}:", "bot_prefix")):]
+ previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", question)]
+ idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", reply)]
+ idx = idx[max(len(previous_idx)-1, 0)]
+
+ if not impersonate:
+ reply = reply[idx + 1 + len(apply_extensions(f"{replier}:", "bot_prefix")):]
else:
- reply = reply[idx + 1 + len(f"{current}:"):]
+ reply = reply[idx + 1 + len(f"{replier}:"):]
if check:
- reply = reply.split('\n')[0].strip()
+ lines = reply.split('\n')
+ reply = lines[0].strip()
+ if len(lines) > 1:
+ next_character_found = True
else:
- idx = reply.find(f"\n{other}:")
+ idx = reply.find(f"\n{asker}:")
if idx != -1:
reply = reply[:idx]
next_character_found = True
@@ -75,7 +81,7 @@ def extract_message_from_reply(question, reply, current, other, check, extension
# Detect if something like "\nYo" is generated just before
# "\nYou:" is completed
- tmp = f"\n{other}:"
+ tmp = f"\n{asker}:"
for j in range(1, len(tmp)):
if reply[-j:] == tmp[:j]:
substring_found = True
@@ -89,6 +95,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
shared.stop_everything = False
just_started = True
eos_token = '\n' if check else None
+ name1_original = name1
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
@@ -119,8 +126,9 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
# Extracting the reply
- reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name2, name1, check, extensions=True)
- visible_reply = apply_extensions(reply, "output")
+ reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check)
+ visible_reply = re.sub("(||{{user}})", name1_original, reply)
+ visible_reply = apply_extensions(visible_reply, "output")
if shared.args.chat:
visible_reply = visible_reply.replace('\n', '
')
@@ -139,6 +147,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
yield shared.history['visible']
if next_character_found:
break
+
yield shared.history['visible']
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
@@ -152,7 +161,7 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
reply = ''
for i in range(chat_generation_attempts):
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
- reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, extensions=False)
+ reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
if not substring_found:
yield reply
if next_character_found:
diff --git a/modules/models.py b/modules/models.py
index 904d8ae2..16ce6eb1 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -39,10 +39,9 @@ def load_model(model_name):
t0 = time.time()
shared.is_RWKV = model_name.lower().startswith('rwkv-')
- shared.is_LLaMA = model_name.lower().startswith('llama-')
# Default settings
- if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV or shared.is_LLaMA):
+ if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
@@ -80,20 +79,12 @@ def load_model(model_name):
# RMKV model (not on HuggingFace)
elif shared.is_RWKV:
- from modules.RWKV import RWKVModel
+ from modules.RWKV import RWKVModel, RWKVTokenizer
model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+ tokenizer = RWKVTokenizer.from_pretrained(Path('models'))
- return model, None
-
- # LLaMA model (not on HuggingFace)
- elif shared.is_LLaMA:
- import modules.LLaMA
- from modules.LLaMA import LLaMAModel
-
- model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
-
- return model, None
+ return model, tokenizer
# Custom
else:
diff --git a/modules/shared.py b/modules/shared.py
index 90adb320..8ad45142 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -6,7 +6,6 @@ model_name = ""
soft_prompt_tensor = None
soft_prompt = False
is_RWKV = False
-is_LLaMA = False
# Chat variables
history = {'internal': [], 'visible': []}
@@ -44,7 +43,6 @@ settings = {
'default': 'NovelAI-Sphinx Moth',
'pygmalion-*': 'Pygmalion',
'RWKV-*': 'Naive',
- 'llama-*': 'Naive',
'(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)'
},
'prompts': {
@@ -84,9 +82,10 @@ parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, defaul
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
-parser.add_argument('--rwkv-strategy', type=str, default=None, help='The strategy to use while loading RWKV models. Examples: "cpu fp32", "cuda fp16", "cuda fp16 *30 -> cpu fp32".')
+parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
+parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
-parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
+parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
diff --git a/modules/text_generation.py b/modules/text_generation.py
index c9f4fc6a..5a715e8e 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -21,21 +21,20 @@ def get_max_prompt_length(tokens):
return max_length
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-
- # These models do not have explicit tokenizers for now, so
- # we return an estimate for the number of tokens
- if shared.is_RWKV or shared.is_LLaMA:
- return np.zeros((1, len(prompt)//4))
-
- input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
- if shared.args.cpu:
+ if shared.is_RWKV:
+ input_ids = shared.tokenizer.encode(str(prompt))
+ input_ids = np.array(input_ids).reshape(1, len(input_ids))
return input_ids
- elif shared.args.flexgen:
- return input_ids.numpy()
- elif shared.args.deepspeed:
- return input_ids.to(device=local_rank)
else:
- return input_ids.cuda()
+ input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
+ if shared.args.cpu:
+ return input_ids
+ elif shared.args.flexgen:
+ return input_ids.numpy()
+ elif shared.args.deepspeed:
+ return input_ids.to(device=local_rank)
+ else:
+ return input_ids.cuda()
def decode(output_ids):
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
@@ -81,26 +80,30 @@ def formatted_outputs(reply, model_name):
else:
return reply
-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
+def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
torch.cuda.empty_cache()
+def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
+ clear_torch_cache()
t0 = time.time()
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
- if shared.is_RWKV or shared.is_LLaMA:
+ if shared.is_RWKV:
if shared.args.no_stream:
- reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p)
- t1 = time.time()
- print(f"Output generated in {(t1-t0):.2f} seconds.")
+ reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
yield formatted_outputs(reply, shared.model_name)
else:
- for i in tqdm(range(max_new_tokens//8+1)):
- reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p)
+ yield formatted_outputs(question, shared.model_name)
+ # RWKV has proper streaming, which is very nice.
+ # No need to generate 8 tokens at a time.
+ for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
yield formatted_outputs(reply, shared.model_name)
- question = reply
+
+ t1 = time.time()
+ print(f"Output generated in {(t1-t0):.2f} seconds.")
return
original_question = question
@@ -111,8 +114,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
input_ids = encode(question, max_new_tokens)
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
- n = shared.tokenizer.eos_token_id if eos_token is None else encode(eos_token)[0][-1]
-
+ n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
if stopping_string is not None:
# The stopping_criteria code below was copied from
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
@@ -149,14 +151,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
f"temperature={temperature}",
f"stop={n}",
]
-
if shared.args.deepspeed:
generate_params.append("synced_gpus=True")
if shared.args.no_stream:
generate_params.append("max_new_tokens=max_new_tokens")
else:
generate_params.append("max_new_tokens=8")
-
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.insert(0, "inputs_embeds=inputs_embeds")
@@ -184,6 +184,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
yield formatted_outputs(original_question, shared.model_name)
shared.still_streaming = True
for i in tqdm(range(max_new_tokens//8+1)):
+ clear_torch_cache()
+
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
if shared.soft_prompt:
diff --git a/presets/Naive.txt b/presets/Naive.txt
index c6965983..aa8c0582 100644
--- a/presets/Naive.txt
+++ b/presets/Naive.txt
@@ -1,3 +1,4 @@
do_sample=True
-top_p=0.95
-temperature=0.8
+temperature=0.7
+top_p=0.85
+top_k=50
diff --git a/requirements.txt b/requirements.txt
index 48ca1e4e..a8a6eada 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,8 @@ bitsandbytes==0.37.0
flexgen==0.1.7
gradio==3.18.0
numpy
-rwkv==0.0.6
+rwkv==0.1.0
safetensors==0.2.8
-git+https://github.com/huggingface/transformers
tensorboard
+sentencepiece
+git+https://github.com/oobabooga/transformers@llama_push
diff --git a/server.py b/server.py
index ed46224e..9f584ba3 100644
--- a/server.py
+++ b/server.py
@@ -22,8 +22,14 @@ if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
# Loading custom settings
+settings_file = None
if shared.args.settings is not None and Path(shared.args.settings).exists():
- new_settings = json.loads(open(Path(shared.args.settings), 'r').read())
+ settings_file = Path(shared.args.settings)
+elif Path('settings.json').exists():
+ settings_file = Path('settings.json')
+if settings_file is not None:
+ print(f"Loading settings from {settings_file}...")
+ new_settings = json.loads(open(settings_file, 'r').read())
for item in new_settings:
shared.settings[item] = new_settings[item]