Merge branch 'oobabooga:main' into stt-extension

This commit is contained in:
Elias Vincent Simon 2023-03-12 19:19:43 +01:00 committed by GitHub
commit 3b4145966d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 462 additions and 185 deletions

View File

@ -1,6 +1,6 @@
# Text generation web UI
A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, GPT-Neo, and Pygmalion.
A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, LLaMA, and Pygmalion.
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
@ -27,6 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
* [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
* Supports softprompts.
* [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
@ -53,7 +54,7 @@ The third line assumes that you have an NVIDIA GPU.
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
```
* If you are running in CPU mode, replace the third command with this one:
* If you are running it in CPU mode, replace the third command with this one:
```
conda install pytorch torchvision torchaudio git -c pytorch
@ -137,6 +138,8 @@ Optionally, you can use the following command-line flags:
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
| `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.|
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
@ -176,14 +179,10 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-
Pull requests, suggestions, and issue reports are welcome.
Before reporting a bug, make sure that you have created a conda environment and installed the dependencies exactly as in the *Installation* section above.
Before reporting a bug, make sure that you have:
These issues are known:
* 8-bit doesn't work properly on Windows or older GPUs.
* DeepSpeed doesn't work properly on Windows.
For these two, please try commenting on an existing issue instead of creating a new one.
1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
2. [Searched](https://github.com/oobabooga/text-generation-webui/issues) to see if an issue already exists for the issue you encountered.
## Credits

View File

@ -5,7 +5,9 @@ Example:
python download-model.py facebook/opt-1.3b
'''
import argparse
import base64
import json
import multiprocessing
import re
@ -93,23 +95,28 @@ facebook/opt-1.3b
def get_download_links_from_huggingface(model, branch):
base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor="
cursor = b""
links = []
classifications = []
has_pytorch = False
has_safetensors = False
while page is not None:
content = requests.get(f"{base}{page}").content
while True:
content = requests.get(f"{base}{page}{cursor.decode()}").content
dict = json.loads(content)
if len(dict) == 0:
break
for i in range(len(dict)):
fname = dict[i]['path']
is_pytorch = re.match("pytorch_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
is_text = re.match(".*\.(txt|json)", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
if is_text or is_safetensors or is_pytorch:
if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text')
@ -123,8 +130,9 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True
classifications.append('pytorch')
#page = dict['nextUrl']
page = None
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
cursor = base64.b64encode(cursor)
cursor = cursor.replace(b'=', b'%3D')
# If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors:

View File

@ -0,0 +1,18 @@
import gradio as gr
import modules.shared as shared
import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
def get_prompt_by_name(name):
if name == 'None':
return ''
else:
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
def ui():
if not shared.args.chat or shared.args.cai_chat:
choices = ['None'] + list(df['Prompt name'])
prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox'])

View File

@ -1,21 +1,45 @@
import re
import time
from pathlib import Path
import gradio as gr
import torch
import modules.chat as chat
import modules.shared as shared
torch._C._jit_set_profiling_mode(False)
params = {
'activate': True,
'speaker': 'en_56',
'speaker': 'en_5',
'language': 'en',
'model_id': 'v3_en',
'sample_rate': 48000,
'device': 'cpu',
'show_text': False,
'autoplay': True,
'voice_pitch': 'medium',
'voice_speed': 'medium',
}
current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
wav_idx = 0
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
last_msg_id = 0
# Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"'": "&apos;",
'"': "&quot;",
})
def xmlesc(txt):
return txt.translate(table)
def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
@ -33,12 +57,59 @@ def remove_surrounded_chars(string):
new_string += char
return new_string
def remove_tts_from_history():
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
for i, entry in enumerate(shared.history['internal']):
reply = entry[1]
reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
if shared.args.chat:
reply = reply.replace('\n', '<br>')
shared.history['visible'][i][1] = reply
if shared.args.cai_chat:
return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
else:
return shared.history['visible']
def toggle_text_in_history():
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
audio_str='\n\n' # The '\n\n' used after </audio>
if shared.args.chat:
audio_str='<br><br>'
if params['show_text']==True:
#for i, entry in enumerate(shared.history['internal']):
for i, entry in enumerate(shared.history['visible']):
vis_reply = entry[1]
if vis_reply.startswith('<audio'):
reply = shared.history['internal'][i][1]
reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
if shared.args.chat:
reply = reply.replace('\n', '<br>')
shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str+reply
else:
for i, entry in enumerate(shared.history['visible']):
vis_reply = entry[1]
if vis_reply.startswith('<audio'):
shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str
if shared.args.cai_chat:
return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
else:
return shared.history['visible']
def input_modifier(string):
"""
This function is applied to your text inputs before
they are fed into the model.
"""
# Remove autoplay from previous chat history
if (shared.args.chat or shared.args.cai_chat)and len(shared.history['internal'])>0:
[visible_text, visible_reply] = shared.history['visible'][-1]
vis_rep_clean = visible_reply.replace('controls autoplay>','controls>')
shared.history['visible'][-1] = [visible_text, vis_rep_clean]
return string
def output_modifier(string):
@ -46,7 +117,7 @@ def output_modifier(string):
This function is applied to the model outputs.
"""
global wav_idx, model, current_params
global model, current_params
for i in params:
if params[i] != current_params[i]:
@ -57,20 +128,34 @@ def output_modifier(string):
if params['activate'] == False:
return string
orig_string = string
string = remove_surrounded_chars(string)
string = string.replace('"', '')
string = string.replace('', '')
string = string.replace('\n', ' ')
string = string.strip()
silent_string = False # Used to prevent unnecessary audio file generation
if string == '':
string = 'empty reply, try regenerating'
silent_string = True
output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
pitch = params['voice_pitch']
speed = params['voice_speed']
prosody=f'<prosody rate="{speed}" pitch="{pitch}">'
string = '<speak>'+prosody+xmlesc(string)+'</prosody></speak>'
string = f'<audio src="file/{output_file.as_posix()}" controls></audio>'
wav_idx += 1
if not shared.still_streaming and not silent_string:
output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
autoplay_str = ' autoplay' if params['autoplay'] else ''
string = f'<audio src="file/{output_file.as_posix()}" controls{autoplay_str}></audio>\n\n'
else:
# Placeholder so text doesn't shift around so much
string = '<audio controls></audio>\n\n'
if params['show_text']:
string += orig_string
return string
@ -85,9 +170,36 @@ def bot_prefix_modifier(string):
def ui():
# Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Accordion("Silero TTS"):
with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Row():
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
with gr.Row():
convert = gr.Button('Permanently replace chat history audio with message text')
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
convert_cancel = gr.Button('Cancel', visible=False)
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_confirm.click(remove_tts_from_history, [], shared.gradio['display'])
convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
show_text.change(toggle_text_in_history, [], shared.gradio['display'])
show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
# Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None)
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
voice.change(lambda x: params.update({"speaker": x}), voice, None)
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)

View File

@ -1,12 +1,11 @@
import os
from pathlib import Path
from queue import Queue
from threading import Thread
import numpy as np
from tokenizers import Tokenizer
import modules.shared as shared
from modules.callbacks import Iteratorize
np.set_printoptions(precision=4, suppress=True, linewidth=200)
@ -49,11 +48,11 @@ class RWKVModel:
return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
def generate_with_streaming(self, **kwargs):
iterable = Iteratorize(self.generate, kwargs, callback=None)
reply = kwargs['context']
for token in iterable:
reply += token
yield reply
with Iteratorize(self.generate, kwargs, callback=None) as generator:
reply = kwargs['context']
for token in generator:
reply += token
yield reply
class RWKVTokenizer:
def __init__(self):
@ -73,38 +72,3 @@ class RWKVTokenizer:
def decode(self, ids):
return self.tokenizer.decode(ids)
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}, callback=None):
self.mfunc=func
self.c_callback=callback
self.q = Queue(maxsize=1)
self.sentinel = object()
self.kwargs = kwargs
def _callback(val):
self.q.put(val)
def gentask():
ret = self.mfunc(callback=_callback, **self.kwargs)
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
Thread(target=gentask).start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True,None)
if obj is self.sentinel:
raise StopIteration
else:
return obj

98
modules/callbacks.py Normal file
View File

@ -0,0 +1,98 @@
import gc
from queue import Queue
from threading import Thread
import torch
import transformers
import modules.shared as shared
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: torch.LongTensor,
starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
def __call__(self, input_ids: torch.LongTensor,
_scores: torch.FloatTensor) -> bool:
for sample in input_ids:
trimmed_sample = sample[self.starting_idx:]
# Can't unfold, output is still too tiny. Skip.
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
continue
for window in trimmed_sample.unfold(
0, self.sentinel_token_ids.shape[-1], 1):
if torch.all(torch.eq(self.sentinel_token_ids, window)):
return True
return False
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}, callback=None):
self.mfunc=func
self.c_callback=callback
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
if self.stop_now:
raise ValueError
self.q.put(val)
def gentask():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
clear_torch_cache()
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
self.thread = Thread(target=gentask)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True,None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __del__(self):
clear_torch_cache()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True
clear_torch_cache()
def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
torch.cuda.empty_cache()

View File

@ -84,6 +84,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
tmp = f"\n{asker}:"
for j in range(1, len(tmp)):
if reply[-j:] == tmp[:j]:
reply = reply[:-j]
substring_found = True
return reply, next_character_found, substring_found
@ -91,7 +92,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
def stop_everything_event():
shared.stop_everything = True
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
shared.stop_everything = False
just_started = True
eos_token = '\n' if check else None
@ -120,6 +121,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
else:
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
if not regenerate:
# Display user input and "*is typing...*" imediately
yield shared.history['visible']+[[visible_text, '*Is typing...*']]
# Generate
reply = ''
for i in range(chat_generation_attempts):
@ -158,6 +163,9 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
# Display "*is typing...*" imediately
yield '*Is typing...*'
reply = ''
for i in range(chat_generation_attempts):
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
@ -182,7 +190,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop()
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
@ -291,7 +299,7 @@ def save_history(timestamp=True):
fname = f"{prefix}persistent.json"
if not Path('logs').exists():
Path('logs').mkdir()
with open(Path(f'logs/{fname}'), 'w') as f:
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}')
@ -332,7 +340,7 @@ def load_character(_character, name1, name2):
shared.history['visible'] = []
if _character != 'None':
shared.character = _character
data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read())
data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read())
name2 = data['char_name']
if 'char_persona' in data and data['char_persona'] != '':
context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
@ -372,7 +380,7 @@ def upload_character(json_file, img, tavern=False):
i += 1
if tavern:
outfile_name = f'TavernAI-{outfile_name}'
with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
f.write(json_file)
if img is not None:
img = Image.open(io.BytesIO(img))

View File

@ -1,5 +1,6 @@
import json
import os
import sys
import time
import zipfile
from pathlib import Path
@ -41,7 +42,7 @@ def load_model(model_name):
shared.is_RWKV = model_name.lower().startswith('rwkv-')
# Default settings
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.gptq_bits > 0, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
@ -86,6 +87,12 @@ def load_model(model_name):
return model, tokenizer
# 4-bit LLaMA
elif shared.args.gptq_bits > 0 or shared.args.load_in_4bit:
from modules.quantized_LLaMA import load_quantized_LLaMA
model = load_quantized_LLaMA(model_name)
# Custom
else:
command = "AutoModelForCausalLM.from_pretrained"

View File

@ -0,0 +1,60 @@
import os
import sys
from pathlib import Path
import accelerate
import torch
import modules.shared as shared
sys.path.insert(0, os.path.abspath(Path("repositories/GPTQ-for-LLaMa")))
from llama import load_quant
# 4-bit LLaMA
def load_quantized_LLaMA(model_name):
if shared.args.load_in_4bit:
bits = 4
else:
bits = shared.args.gptq_bits
path_to_model = Path(f'models/{model_name}')
pt_model = ''
if path_to_model.name.lower().startswith('llama-7b'):
pt_model = f'llama-7b-{bits}bit.pt'
elif path_to_model.name.lower().startswith('llama-13b'):
pt_model = f'llama-13b-{bits}bit.pt'
elif path_to_model.name.lower().startswith('llama-30b'):
pt_model = f'llama-30b-{bits}bit.pt'
elif path_to_model.name.lower().startswith('llama-65b'):
pt_model = f'llama-65b-{bits}bit.pt'
else:
pt_model = f'{model_name}-{bits}bit.pt'
# Try to find the .pt both in models/ and in the subfolder
pt_path = None
for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists():
pt_path = path
if not pt_path:
print(f"Could not find {pt_model}, exiting...")
exit()
model = load_quant(path_to_model, os.path.abspath(pt_path), bits)
# Multi-GPU setup
if shared.args.gpu_memory:
max_memory = {}
for i in range(len(shared.args.gpu_memory)):
max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
model = accelerate.dispatch_model(model, device_map=device_map)
# Single GPU
else:
model = model.to(torch.device('cuda:0'))
return model

View File

@ -11,6 +11,7 @@ is_RWKV = False
history = {'internal': [], 'visible': []}
character = 'None'
stop_everything = False
still_streaming = False
# UI elements (buttons, sliders, HTML, etc)
gradio = {}
@ -42,12 +43,12 @@ settings = {
'default': 'NovelAI-Sphinx Moth',
'pygmalion-*': 'Pygmalion',
'RWKV-*': 'Naive',
'(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)'
},
'prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
'(rosey|chip|joi)_.*_instruct.*': 'User: \n'
'(rosey|chip|joi)_.*_instruct.*': 'User: \n',
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
}
}
@ -68,6 +69,8 @@ parser.add_argument('--chat', action='store_true', help='Launch the web UI in ch
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision. Currently only works with LLaMA.')
parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
@ -90,4 +93,5 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch')
args = parser.parse_args()

View File

@ -1,32 +0,0 @@
'''
This code was copied from
https://github.com/PygmalionAI/gradio-ui/
'''
import torch
import transformers
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: torch.LongTensor,
starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
def __call__(self, input_ids: torch.LongTensor,
_scores: torch.FloatTensor) -> bool:
for sample in input_ids:
trimmed_sample = sample[self.starting_idx:]
# Can't unfold, output is still too tiny. Skip.
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
continue
for window in trimmed_sample.unfold(
0, self.sentinel_token_ids.shape[-1], 1):
if torch.all(torch.eq(self.sentinel_token_ids, window)):
return True
return False

View File

@ -5,13 +5,13 @@ import time
import numpy as np
import torch
import transformers
from tqdm import tqdm
import modules.shared as shared
from modules.callbacks import (Iteratorize, Stream,
_SentinelTokenStoppingCriteria)
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
def get_max_prompt_length(tokens):
@ -92,19 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
if shared.is_RWKV:
if shared.args.no_stream:
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
yield formatted_outputs(reply, shared.model_name)
else:
yield formatted_outputs(question, shared.model_name)
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
yield formatted_outputs(reply, shared.model_name)
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds.")
return
else:
yield formatted_outputs(question, shared.model_name)
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
yield formatted_outputs(reply, shared.model_name)
finally:
t1 = time.time()
output = encode(reply)[0]
input_ids = encode(question)
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
return
original_question = question
if not (shared.args.chat or shared.args.cai_chat):
@ -113,24 +116,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"\n\n{question}\n--------------------\n")
input_ids = encode(question, max_new_tokens)
original_input_ids = input_ids
output = input_ids[0]
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
eos_token_ids = [shared.tokenizer.eos_token_id]
if eos_token is not None:
eos_token_ids.append(int(encode(eos_token)[0][-1]))
stopping_criteria_list = transformers.StoppingCriteriaList()
if stopping_string is not None:
# The stopping_criteria code below was copied from
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
# Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
t = encode(stopping_string, 0, add_special_tokens=False)
stopping_criteria_list = transformers.StoppingCriteriaList([
_SentinelTokenStoppingCriteria(
sentinel_token_ids=t,
starting_idx=len(input_ids[0])
)
])
else:
stopping_criteria_list = None
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
if not shared.args.flexgen:
generate_params = [
f"eos_token_id={n}",
f"max_new_tokens=max_new_tokens",
f"eos_token_id={eos_token_ids}",
f"stopping_criteria=stopping_criteria_list",
f"do_sample={do_sample}",
f"temperature={temperature}",
@ -147,44 +148,23 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
]
else:
generate_params = [
f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
f"do_sample={do_sample}",
f"temperature={temperature}",
f"stop={n}",
f"stop={eos_token_ids[-1]}",
]
if shared.args.deepspeed:
generate_params.append("synced_gpus=True")
if shared.args.no_stream:
generate_params.append("max_new_tokens=max_new_tokens")
else:
generate_params.append("max_new_tokens=8")
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.insert(0, "inputs_embeds=inputs_embeds")
generate_params.insert(0, "filler_input_ids")
generate_params.insert(0, "inputs=filler_input_ids")
else:
generate_params.insert(0, "input_ids")
# Generate the entire reply at once
if shared.args.no_stream:
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
yield formatted_outputs(reply, shared.model_name)
# Generate the reply 8 tokens at a time
else:
yield formatted_outputs(original_question, shared.model_name)
for i in tqdm(range(max_new_tokens//8+1)):
clear_torch_cache()
generate_params.insert(0, "inputs=input_ids")
try:
# Generate the entire reply at once.
if shared.args.no_stream:
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
if shared.soft_prompt:
@ -193,16 +173,66 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
if not shared.args.flexgen:
if output[-1] == n:
break
input_ids = torch.reshape(output, (1, output.shape[0]))
else:
if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
break
input_ids = np.reshape(output, (1, output.shape[0]))
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator.
elif not shared.args.flexgen:
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
def generate_with_callback(callback=None, **kwargs):
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
clear_torch_cache()
with torch.no_grad():
shared.model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
shared.still_streaming = True
yield formatted_outputs(original_question, shared.model_name)
with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
if output[-1] in eos_token_ids:
break
yield formatted_outputs(reply, shared.model_name)
shared.still_streaming = False
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
shared.still_streaming = True
for i in range(max_new_tokens//8+1):
clear_torch_cache()
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
yield formatted_outputs(reply, shared.model_name)
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
shared.still_streaming = False
yield formatted_outputs(reply, shared.model_name)
finally:
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
return

View File

@ -1,9 +1,11 @@
accelerate==0.16.0
accelerate==0.17.0
bitsandbytes==0.37.0
flexgen==0.1.7
gradio==3.18.0
numpy
rwkv==0.1.0
safetensors==0.2.8
requests
rwkv==0.3.1
safetensors==0.3.0
sentencepiece
git+https://github.com/oobabooga/transformers@llama_push
tqdm
git+https://github.com/zphang/transformers@llama_push

View File

@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html
from modules.models import load_model, load_soft_prompt
from modules.text_generation import generate_reply
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
# Loading custom settings
settings_file = None
if shared.args.settings is not None and Path(shared.args.settings).exists():
@ -37,7 +34,7 @@ def get_available_models():
if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else:
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
@ -272,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat:
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False))
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
@ -309,6 +306,7 @@ if shared.args.chat or shared.args.cai_chat:
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
@ -372,9 +370,9 @@ else:
shared.gradio['interface'].queue()
if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port)
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port)
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
# I think that I will need this later
while True:

View File

@ -29,6 +29,7 @@
"prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
"(rosey|chip|joi)_.*_instruct.*": "User: \n"
"(rosey|chip|joi)_.*_instruct.*": "User: \n",
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
}
}