diff --git a/README.md b/README.md index 65596321..dc5ed659 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Text generation web UI -A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, GPT-Neo, and Pygmalion. +A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, LLaMA, and Pygmalion. Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation. @@ -27,6 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen). * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed). * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming. +* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model). * [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model). * Supports softprompts. * [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions). @@ -53,7 +54,7 @@ The third line assumes that you have an NVIDIA GPU. pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2 ``` -* If you are running in CPU mode, replace the third command with this one: +* If you are running it in CPU mode, replace the third command with this one: ``` conda install pytorch torchvision torchaudio git -c pytorch @@ -137,6 +138,8 @@ Optionally, you can use the following command-line flags: | `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. | | `--cpu` | Use the CPU to generate text.| | `--load-in-8bit` | Load the model with 8-bit precision.| +| `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.| +| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.| | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. | @@ -176,14 +179,10 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System- Pull requests, suggestions, and issue reports are welcome. -Before reporting a bug, make sure that you have created a conda environment and installed the dependencies exactly as in the *Installation* section above. +Before reporting a bug, make sure that you have: -These issues are known: - -* 8-bit doesn't work properly on Windows or older GPUs. -* DeepSpeed doesn't work properly on Windows. - -For these two, please try commenting on an existing issue instead of creating a new one. +1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above. +2. [Searched](https://github.com/oobabooga/text-generation-webui/issues) to see if an issue already exists for the issue you encountered. ## Credits diff --git a/download-model.py b/download-model.py index 599418fc..8be398c4 100644 --- a/download-model.py +++ b/download-model.py @@ -5,7 +5,9 @@ Example: python download-model.py facebook/opt-1.3b ''' + import argparse +import base64 import json import multiprocessing import re @@ -93,23 +95,28 @@ facebook/opt-1.3b def get_download_links_from_huggingface(model, branch): base = "https://huggingface.co" page = f"/api/models/{model}/tree/{branch}?cursor=" + cursor = b"" links = [] classifications = [] has_pytorch = False has_safetensors = False - while page is not None: - content = requests.get(f"{base}{page}").content + while True: + content = requests.get(f"{base}{page}{cursor.decode()}").content + dict = json.loads(content) + if len(dict) == 0: + break for i in range(len(dict)): fname = dict[i]['path'] is_pytorch = re.match("pytorch_model.*\.bin", fname) is_safetensors = re.match("model.*\.safetensors", fname) - is_text = re.match(".*\.(txt|json)", fname) + is_tokenizer = re.match("tokenizer.*\.model", fname) + is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer - if is_text or is_safetensors or is_pytorch: + if any((is_pytorch, is_safetensors, is_text, is_tokenizer)): if is_text: links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") classifications.append('text') @@ -123,8 +130,9 @@ def get_download_links_from_huggingface(model, branch): has_pytorch = True classifications.append('pytorch') - #page = dict['nextUrl'] - page = None + cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' + cursor = base64.b64encode(cursor) + cursor = cursor.replace(b'=', b'%3D') # If both pytorch and safetensors are available, download safetensors only if has_pytorch and has_safetensors: diff --git a/extensions/llama_prompts/script.py b/extensions/llama_prompts/script.py new file mode 100644 index 00000000..22c96f7c --- /dev/null +++ b/extensions/llama_prompts/script.py @@ -0,0 +1,18 @@ +import gradio as gr +import modules.shared as shared +import pandas as pd + +df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv") + +def get_prompt_by_name(name): + if name == 'None': + return '' + else: + return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n') + +def ui(): + if not shared.args.chat or shared.args.cai_chat: + choices = ['None'] + list(df['Prompt name']) + + prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt') + prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox']) diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index 050392d6..62d4b441 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -1,21 +1,45 @@ +import re +import time from pathlib import Path import gradio as gr import torch +import modules.chat as chat +import modules.shared as shared + torch._C._jit_set_profiling_mode(False) params = { 'activate': True, - 'speaker': 'en_56', + 'speaker': 'en_5', 'language': 'en', 'model_id': 'v3_en', 'sample_rate': 48000, 'device': 'cpu', + 'show_text': False, + 'autoplay': True, + 'voice_pitch': 'medium', + 'voice_speed': 'medium', } + current_params = params.copy() voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] -wav_idx = 0 +voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high'] +voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast'] +last_msg_id = 0 + +# Used for making text xml compatible, needed for voice pitch and speed control +table = str.maketrans({ + "<": "<", + ">": ">", + "&": "&", + "'": "'", + '"': """, +}) + +def xmlesc(txt): + return txt.translate(table) def load_model(): model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) @@ -33,12 +57,59 @@ def remove_surrounded_chars(string): new_string += char return new_string +def remove_tts_from_history(): + suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' + for i, entry in enumerate(shared.history['internal']): + reply = entry[1] + reply = re.sub("(||{{user}})", shared.settings[f'name1{suffix}'], reply) + if shared.args.chat: + reply = reply.replace('\n', '
') + shared.history['visible'][i][1] = reply + + if shared.args.cai_chat: + return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character) + else: + return shared.history['visible'] + +def toggle_text_in_history(): + suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' + audio_str='\n\n' # The '\n\n' used after + if shared.args.chat: + audio_str='

' + + if params['show_text']==True: + #for i, entry in enumerate(shared.history['internal']): + for i, entry in enumerate(shared.history['visible']): + vis_reply = entry[1] + if vis_reply.startswith('||{{user}})", shared.settings[f'name1{suffix}'], reply) + if shared.args.chat: + reply = reply.replace('\n', '
') + shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str+reply + else: + for i, entry in enumerate(shared.history['visible']): + vis_reply = entry[1] + if vis_reply.startswith('0: + [visible_text, visible_reply] = shared.history['visible'][-1] + vis_rep_clean = visible_reply.replace('controls autoplay>','controls>') + shared.history['visible'][-1] = [visible_text, vis_rep_clean] + return string def output_modifier(string): @@ -46,7 +117,7 @@ def output_modifier(string): This function is applied to the model outputs. """ - global wav_idx, model, current_params + global model, current_params for i in params: if params[i] != current_params[i]: @@ -57,20 +128,34 @@ def output_modifier(string): if params['activate'] == False: return string + orig_string = string string = remove_surrounded_chars(string) string = string.replace('"', '') string = string.replace('“', '') string = string.replace('\n', ' ') string = string.strip() + silent_string = False # Used to prevent unnecessary audio file generation if string == '': string = 'empty reply, try regenerating' + silent_string = True - output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav') - model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) + pitch = params['voice_pitch'] + speed = params['voice_speed'] + prosody=f'' + string = ''+prosody+xmlesc(string)+'' - string = f'' - wav_idx += 1 + if not shared.still_streaming and not silent_string: + output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav') + model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) + autoplay_str = ' autoplay' if params['autoplay'] else '' + string = f'\n\n' + else: + # Placeholder so text doesn't shift around so much + string = '\n\n' + + if params['show_text']: + string += orig_string return string @@ -85,9 +170,36 @@ def bot_prefix_modifier(string): def ui(): # Gradio elements - activate = gr.Checkbox(value=params['activate'], label='Activate TTS') - voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') + with gr.Accordion("Silero TTS"): + with gr.Row(): + activate = gr.Checkbox(value=params['activate'], label='Activate TTS') + autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically') + show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player') + voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') + with gr.Row(): + v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch') + v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed') + with gr.Row(): + convert = gr.Button('Permanently replace chat history audio with message text') + convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False) + convert_cancel = gr.Button('Cancel', visible=False) + + # Convert history with confirmation + convert_arr = [convert_confirm, convert, convert_cancel] + convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) + convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) + convert_confirm.click(remove_tts_from_history, [], shared.gradio['display']) + convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) + + # Toggle message text in history + show_text.change(lambda x: params.update({"show_text": x}), show_text, None) + show_text.change(toggle_text_in_history, [], shared.gradio['display']) + show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) # Event functions to update the parameters in the backend activate.change(lambda x: params.update({"activate": x}), activate, None) + autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None) voice.change(lambda x: params.update({"speaker": x}), voice, None) + v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None) + v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None) diff --git a/modules/RWKV.py b/modules/RWKV.py index b226a195..d97c1706 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -1,12 +1,11 @@ import os from pathlib import Path -from queue import Queue -from threading import Thread import numpy as np from tokenizers import Tokenizer import modules.shared as shared +from modules.callbacks import Iteratorize np.set_printoptions(precision=4, suppress=True, linewidth=200) @@ -49,11 +48,11 @@ class RWKVModel: return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) def generate_with_streaming(self, **kwargs): - iterable = Iteratorize(self.generate, kwargs, callback=None) - reply = kwargs['context'] - for token in iterable: - reply += token - yield reply + with Iteratorize(self.generate, kwargs, callback=None) as generator: + reply = kwargs['context'] + for token in generator: + reply += token + yield reply class RWKVTokenizer: def __init__(self): @@ -73,38 +72,3 @@ class RWKVTokenizer: def decode(self, ids): return self.tokenizer.decode(ids) - -class Iteratorize: - - """ - Transforms a function that takes a callback - into a lazy iterator (generator). - """ - - def __init__(self, func, kwargs={}, callback=None): - self.mfunc=func - self.c_callback=callback - self.q = Queue(maxsize=1) - self.sentinel = object() - self.kwargs = kwargs - - def _callback(val): - self.q.put(val) - - def gentask(): - ret = self.mfunc(callback=_callback, **self.kwargs) - self.q.put(self.sentinel) - if self.c_callback: - self.c_callback(ret) - - Thread(target=gentask).start() - - def __iter__(self): - return self - - def __next__(self): - obj = self.q.get(True,None) - if obj is self.sentinel: - raise StopIteration - else: - return obj diff --git a/modules/callbacks.py b/modules/callbacks.py new file mode 100644 index 00000000..faa4a5e9 --- /dev/null +++ b/modules/callbacks.py @@ -0,0 +1,98 @@ +import gc +from queue import Queue +from threading import Thread + +import torch +import transformers + +import modules.shared as shared + +# Copied from https://github.com/PygmalionAI/gradio-ui/ +class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): + + def __init__(self, sentinel_token_ids: torch.LongTensor, + starting_idx: int): + transformers.StoppingCriteria.__init__(self) + self.sentinel_token_ids = sentinel_token_ids + self.starting_idx = starting_idx + + def __call__(self, input_ids: torch.LongTensor, + _scores: torch.FloatTensor) -> bool: + for sample in input_ids: + trimmed_sample = sample[self.starting_idx:] + # Can't unfold, output is still too tiny. Skip. + if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]: + continue + + for window in trimmed_sample.unfold( + 0, self.sentinel_token_ids.shape[-1], 1): + if torch.all(torch.eq(self.sentinel_token_ids, window)): + return True + return False + +class Stream(transformers.StoppingCriteria): + def __init__(self, callback_func=None): + self.callback_func = callback_func + + def __call__(self, input_ids, scores) -> bool: + if self.callback_func is not None: + self.callback_func(input_ids[0]) + return False + +class Iteratorize: + + """ + Transforms a function that takes a callback + into a lazy iterator (generator). + """ + + def __init__(self, func, kwargs={}, callback=None): + self.mfunc=func + self.c_callback=callback + self.q = Queue() + self.sentinel = object() + self.kwargs = kwargs + self.stop_now = False + + def _callback(val): + if self.stop_now: + raise ValueError + self.q.put(val) + + def gentask(): + try: + ret = self.mfunc(callback=_callback, **self.kwargs) + except ValueError: + pass + clear_torch_cache() + self.q.put(self.sentinel) + if self.c_callback: + self.c_callback(ret) + + self.thread = Thread(target=gentask) + self.thread.start() + + def __iter__(self): + return self + + def __next__(self): + obj = self.q.get(True,None) + if obj is self.sentinel: + raise StopIteration + else: + return obj + + def __del__(self): + clear_torch_cache() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop_now = True + clear_torch_cache() + +def clear_torch_cache(): + gc.collect() + if not shared.args.cpu: + torch.cuda.empty_cache() diff --git a/modules/chat.py b/modules/chat.py index f40f8299..2048e2c5 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -84,6 +84,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate tmp = f"\n{asker}:" for j in range(1, len(tmp)): if reply[-j:] == tmp[:j]: + reply = reply[:-j] substring_found = True return reply, next_character_found, substring_found @@ -91,7 +92,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate def stop_everything_event(): shared.stop_everything = True -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): +def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): shared.stop_everything = False just_started = True eos_token = '\n' if check else None @@ -120,6 +121,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical else: prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) + if not regenerate: + # Display user input and "*is typing...*" imediately + yield shared.history['visible']+[[visible_text, '*Is typing...*']] + # Generate reply = '' for i in range(chat_generation_attempts): @@ -158,6 +163,9 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) + # Display "*is typing...*" imediately + yield '*Is typing...*' + reply = '' for i in range(chat_generation_attempts): for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): @@ -182,7 +190,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() - for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): + for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) @@ -291,7 +299,7 @@ def save_history(timestamp=True): fname = f"{prefix}persistent.json" if not Path('logs').exists(): Path('logs').mkdir() - with open(Path(f'logs/{fname}'), 'w') as f: + with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f: f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) return Path(f'logs/{fname}') @@ -332,7 +340,7 @@ def load_character(_character, name1, name2): shared.history['visible'] = [] if _character != 'None': shared.character = _character - data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read()) + data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read()) name2 = data['char_name'] if 'char_persona' in data and data['char_persona'] != '': context += f"{data['char_name']}'s Persona: {data['char_persona']}\n" @@ -372,7 +380,7 @@ def upload_character(json_file, img, tavern=False): i += 1 if tavern: outfile_name = f'TavernAI-{outfile_name}' - with open(Path(f'characters/{outfile_name}.json'), 'w') as f: + with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f: f.write(json_file) if img is not None: img = Image.open(io.BytesIO(img)) diff --git a/modules/models.py b/modules/models.py index 16ce6eb1..7d094ed5 100644 --- a/modules/models.py +++ b/modules/models.py @@ -1,5 +1,6 @@ import json import os +import sys import time import zipfile from pathlib import Path @@ -41,7 +42,7 @@ def load_model(model_name): shared.is_RWKV = model_name.lower().startswith('rwkv-') # Default settings - if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV): + if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.gptq_bits > 0, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) else: @@ -86,6 +87,12 @@ def load_model(model_name): return model, tokenizer + # 4-bit LLaMA + elif shared.args.gptq_bits > 0 or shared.args.load_in_4bit: + from modules.quantized_LLaMA import load_quantized_LLaMA + + model = load_quantized_LLaMA(model_name) + # Custom else: command = "AutoModelForCausalLM.from_pretrained" diff --git a/modules/quantized_LLaMA.py b/modules/quantized_LLaMA.py new file mode 100644 index 00000000..5e4a38e8 --- /dev/null +++ b/modules/quantized_LLaMA.py @@ -0,0 +1,60 @@ +import os +import sys +from pathlib import Path + +import accelerate +import torch + +import modules.shared as shared + +sys.path.insert(0, os.path.abspath(Path("repositories/GPTQ-for-LLaMa"))) +from llama import load_quant + + +# 4-bit LLaMA +def load_quantized_LLaMA(model_name): + if shared.args.load_in_4bit: + bits = 4 + else: + bits = shared.args.gptq_bits + + path_to_model = Path(f'models/{model_name}') + pt_model = '' + if path_to_model.name.lower().startswith('llama-7b'): + pt_model = f'llama-7b-{bits}bit.pt' + elif path_to_model.name.lower().startswith('llama-13b'): + pt_model = f'llama-13b-{bits}bit.pt' + elif path_to_model.name.lower().startswith('llama-30b'): + pt_model = f'llama-30b-{bits}bit.pt' + elif path_to_model.name.lower().startswith('llama-65b'): + pt_model = f'llama-65b-{bits}bit.pt' + else: + pt_model = f'{model_name}-{bits}bit.pt' + + # Try to find the .pt both in models/ and in the subfolder + pt_path = None + for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]: + if path.exists(): + pt_path = path + + if not pt_path: + print(f"Could not find {pt_model}, exiting...") + exit() + + model = load_quant(path_to_model, os.path.abspath(pt_path), bits) + + # Multi-GPU setup + if shared.args.gpu_memory: + max_memory = {} + for i in range(len(shared.args.gpu_memory)): + max_memory[i] = f"{shared.args.gpu_memory[i]}GiB" + max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB" + + device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"]) + model = accelerate.dispatch_model(model, device_map=device_map) + + # Single GPU + else: + model = model.to(torch.device('cuda:0')) + + return model diff --git a/modules/shared.py b/modules/shared.py index b609045c..a06c9774 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -11,6 +11,7 @@ is_RWKV = False history = {'internal': [], 'visible': []} character = 'None' stop_everything = False +still_streaming = False # UI elements (buttons, sliders, HTML, etc) gradio = {} @@ -42,12 +43,12 @@ settings = { 'default': 'NovelAI-Sphinx Moth', 'pygmalion-*': 'Pygmalion', 'RWKV-*': 'Naive', - '(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)' }, 'prompts': { 'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n', - '(rosey|chip|joi)_.*_instruct.*': 'User: \n' + '(rosey|chip|joi)_.*_instruct.*': 'User: \n', + 'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>' } } @@ -68,6 +69,8 @@ parser.add_argument('--chat', action='store_true', help='Launch the web UI in ch parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') +parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision. Currently only works with LLaMA.') +parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA.') parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.') @@ -90,4 +93,5 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') +parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch') args = parser.parse_args() diff --git a/modules/stopping_criteria.py b/modules/stopping_criteria.py deleted file mode 100644 index 44a631b3..00000000 --- a/modules/stopping_criteria.py +++ /dev/null @@ -1,32 +0,0 @@ -''' -This code was copied from - -https://github.com/PygmalionAI/gradio-ui/ - -''' - -import torch -import transformers - - -class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): - - def __init__(self, sentinel_token_ids: torch.LongTensor, - starting_idx: int): - transformers.StoppingCriteria.__init__(self) - self.sentinel_token_ids = sentinel_token_ids - self.starting_idx = starting_idx - - def __call__(self, input_ids: torch.LongTensor, - _scores: torch.FloatTensor) -> bool: - for sample in input_ids: - trimmed_sample = sample[self.starting_idx:] - # Can't unfold, output is still too tiny. Skip. - if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]: - continue - - for window in trimmed_sample.unfold( - 0, self.sentinel_token_ids.shape[-1], 1): - if torch.all(torch.eq(self.sentinel_token_ids, window)): - return True - return False diff --git a/modules/text_generation.py b/modules/text_generation.py index 4af53273..7cf68c06 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -5,13 +5,13 @@ import time import numpy as np import torch import transformers -from tqdm import tqdm import modules.shared as shared +from modules.callbacks import (Iteratorize, Stream, + _SentinelTokenStoppingCriteria) from modules.extensions import apply_extensions from modules.html_generator import generate_4chan_html, generate_basic_html from modules.models import local_rank -from modules.stopping_criteria import _SentinelTokenStoppingCriteria def get_max_prompt_length(tokens): @@ -92,19 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # These models are not part of Hugging Face, so we handle them # separately and terminate the function call earlier if shared.is_RWKV: - if shared.args.no_stream: - reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) - yield formatted_outputs(reply, shared.model_name) - else: - yield formatted_outputs(question, shared.model_name) - # RWKV has proper streaming, which is very nice. - # No need to generate 8 tokens at a time. - for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k): + try: + if shared.args.no_stream: + reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) yield formatted_outputs(reply, shared.model_name) - - t1 = time.time() - print(f"Output generated in {(t1-t0):.2f} seconds.") - return + else: + yield formatted_outputs(question, shared.model_name) + # RWKV has proper streaming, which is very nice. + # No need to generate 8 tokens at a time. + for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k): + yield formatted_outputs(reply, shared.model_name) + finally: + t1 = time.time() + output = encode(reply)[0] + input_ids = encode(question) + print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)") + return original_question = question if not (shared.args.chat or shared.args.cai_chat): @@ -113,24 +116,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi print(f"\n\n{question}\n--------------------\n") input_ids = encode(question, max_new_tokens) + original_input_ids = input_ids + output = input_ids[0] cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" - n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1]) + eos_token_ids = [shared.tokenizer.eos_token_id] + if eos_token is not None: + eos_token_ids.append(int(encode(eos_token)[0][-1])) + stopping_criteria_list = transformers.StoppingCriteriaList() if stopping_string is not None: - # The stopping_criteria code below was copied from - # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py + # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py t = encode(stopping_string, 0, add_special_tokens=False) - stopping_criteria_list = transformers.StoppingCriteriaList([ - _SentinelTokenStoppingCriteria( - sentinel_token_ids=t, - starting_idx=len(input_ids[0]) - ) - ]) - else: - stopping_criteria_list = None + stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0]))) if not shared.args.flexgen: generate_params = [ - f"eos_token_id={n}", + f"max_new_tokens=max_new_tokens", + f"eos_token_id={eos_token_ids}", f"stopping_criteria=stopping_criteria_list", f"do_sample={do_sample}", f"temperature={temperature}", @@ -147,44 +148,23 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi ] else: generate_params = [ + f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}", f"do_sample={do_sample}", f"temperature={temperature}", - f"stop={n}", + f"stop={eos_token_ids[-1]}", ] if shared.args.deepspeed: generate_params.append("synced_gpus=True") - if shared.args.no_stream: - generate_params.append("max_new_tokens=max_new_tokens") - else: - generate_params.append("max_new_tokens=8") if shared.soft_prompt: inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) generate_params.insert(0, "inputs_embeds=inputs_embeds") - generate_params.insert(0, "filler_input_ids") + generate_params.insert(0, "inputs=filler_input_ids") else: - generate_params.insert(0, "input_ids") - - # Generate the entire reply at once - if shared.args.no_stream: - with torch.no_grad(): - output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] - if shared.soft_prompt: - output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - - reply = decode(output) - if not (shared.args.chat or shared.args.cai_chat): - reply = original_question + apply_extensions(reply[len(question):], "output") - - t1 = time.time() - print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)") - yield formatted_outputs(reply, shared.model_name) - - # Generate the reply 8 tokens at a time - else: - yield formatted_outputs(original_question, shared.model_name) - for i in tqdm(range(max_new_tokens//8+1)): - clear_torch_cache() + generate_params.insert(0, "inputs=input_ids") + try: + # Generate the entire reply at once. + if shared.args.no_stream: with torch.no_grad(): output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] if shared.soft_prompt: @@ -193,16 +173,66 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi reply = decode(output) if not (shared.args.chat or shared.args.cai_chat): reply = original_question + apply_extensions(reply[len(question):], "output") + yield formatted_outputs(reply, shared.model_name) - if not shared.args.flexgen: - if output[-1] == n: - break - input_ids = torch.reshape(output, (1, output.shape[0])) - else: - if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n): - break - input_ids = np.reshape(output, (1, output.shape[0])) + # Stream the reply 1 token at a time. + # This is based on the trick of using 'stopping_criteria' to create an iterator. + elif not shared.args.flexgen: - if shared.soft_prompt: - inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) + def generate_with_callback(callback=None, **kwargs): + kwargs['stopping_criteria'].append(Stream(callback_func=callback)) + clear_torch_cache() + with torch.no_grad(): + shared.model.generate(**kwargs) + + def generate_with_streaming(**kwargs): + return Iteratorize(generate_with_callback, kwargs, callback=None) + + shared.still_streaming = True + yield formatted_outputs(original_question, shared.model_name) + with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator: + for output in generator: + if shared.soft_prompt: + output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) + reply = decode(output) + + if not (shared.args.chat or shared.args.cai_chat): + reply = original_question + apply_extensions(reply[len(question):], "output") + + if output[-1] in eos_token_ids: + break + yield formatted_outputs(reply, shared.model_name) + + shared.still_streaming = False + yield formatted_outputs(reply, shared.model_name) + + # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' + else: + shared.still_streaming = True + for i in range(max_new_tokens//8+1): + clear_torch_cache() + with torch.no_grad(): + output = eval(f"shared.model.generate({', '.join(generate_params)})")[0] + if shared.soft_prompt: + output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) + reply = decode(output) + + if not (shared.args.chat or shared.args.cai_chat): + reply = original_question + apply_extensions(reply[len(question):], "output") + + if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)): + break + yield formatted_outputs(reply, shared.model_name) + + input_ids = np.reshape(output, (1, output.shape[0])) + if shared.soft_prompt: + inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) + + shared.still_streaming = False + yield formatted_outputs(reply, shared.model_name) + + finally: + t1 = time.time() + print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)") + return diff --git a/requirements.txt b/requirements.txt index 47c56a45..b078ecf4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ -accelerate==0.16.0 +accelerate==0.17.0 bitsandbytes==0.37.0 flexgen==0.1.7 gradio==3.18.0 numpy -rwkv==0.1.0 -safetensors==0.2.8 +requests +rwkv==0.3.1 +safetensors==0.3.0 sentencepiece -git+https://github.com/oobabooga/transformers@llama_push +tqdm +git+https://github.com/zphang/transformers@llama_push diff --git a/server.py b/server.py index 7d8792b7..47e9c8ed 100644 --- a/server.py +++ b/server.py @@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html from modules.models import load_model, load_soft_prompt from modules.text_generation import generate_reply -if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: - print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n') - # Loading custom settings settings_file = None if shared.args.settings is not None and Path(shared.args.settings).exists(): @@ -37,7 +34,7 @@ def get_available_models(): if shared.args.flexgen: return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower) else: - return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower) + return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower) def get_available_presets(): return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) @@ -272,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat: function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' - gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen')) - gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False)) + gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False)) + gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False)) shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) @@ -309,6 +306,7 @@ if shared.args.chat or shared.args.cai_chat: reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) + shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) @@ -372,9 +370,9 @@ else: shared.gradio['interface'].queue() if shared.args.listen: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) else: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) # I think that I will need this later while True: diff --git a/settings-template.json b/settings-template.json index 6585f313..9da43970 100644 --- a/settings-template.json +++ b/settings-template.json @@ -29,6 +29,7 @@ "prompts": { "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n", - "(rosey|chip|joi)_.*_instruct.*": "User: \n" + "(rosey|chip|joi)_.*_instruct.*": "User: \n", + "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>" } }