diff --git a/docs/DeepSpeed.md b/docs/DeepSpeed.md
index 70cd8151..6170f681 100644
--- a/docs/DeepSpeed.md
+++ b/docs/DeepSpeed.md
@@ -9,7 +9,8 @@ As far as I know, DeepSpeed is only available for Linux at the moment.
 1. Install DeepSpeed:
 
 ```
-pip install deepspeed
+conda install -c conda-forge mpi4py mpich
+pip install -U deepspeed
 ```
 
 2. Start the web UI replacing `python` with `deepspeed --num_gpus=1` and adding the `--deepspeed` flag. Example:
diff --git a/modules/models.py b/modules/models.py
index 0bbc574c..5696b624 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -20,9 +20,6 @@ from modules import llama_attn_hijack
 
 transformers.logging.set_verbosity_error()
 
-if shared.args.flexgen:
-    from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
-
 local_rank = None
 if shared.args.deepspeed:
     import deepspeed
@@ -40,6 +37,8 @@ if shared.args.deepspeed:
     dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
 
 
+# Some models require special treatment in various parts of the code.
+# This function detects those models
 def find_model_type(model_name):
     model_name_lower = model_name.lower()
     if 'rwkv-' in model_name_lower:
@@ -72,7 +71,60 @@ def load_model(model_name):
     t0 = time.time()
 
     shared.model_type = find_model_type(model_name)
-    trust_remote_code = shared.args.trust_remote_code
+    if shared.args.wbits > 0:
+        load_func = GPTQ_loader
+    elif shared.model_type == 'llamacpp':
+        load_func = llamacpp_loader
+    elif shared.model_type == 'rwkv':
+        load_func = RWKV_loader
+    elif shared.args.flexgen:
+        load_func = flexgen_loader
+    else:
+        load_func = huggingface_loader
+
+    output = load_func(model_name)
+    if type(output) is tuple:
+        model, tokenizer = output
+    else:
+        model = output
+        tokenizer = load_tokenizer(model_name, model)
+
+    # Hijack attention with xformers
+    if any((shared.args.xformers, shared.args.sdp_attention)):
+        llama_attn_hijack.hijack_llama_attention()
+
+    logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
+    return model, tokenizer
+
+
+def load_tokenizer(model_name, model):
+    if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+    elif type(model) is transformers.LlamaForCausalLM:
+        # Try to load an universal LLaMA tokenizer
+        if shared.model_type not in ['llava', 'oasst']:
+            for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
+                if p.exists():
+                    logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
+                    tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
+                    return tokenizer
+
+        # Otherwise, load it from the model folder and hope that these
+        # are not outdated tokenizer files.
+        tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
+        try:
+            tokenizer.eos_token_id = 2
+            tokenizer.bos_token_id = 1
+            tokenizer.pad_token_id = 0
+        except:
+            pass
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=shared.args.trust_remote_code)
+
+    return tokenizer
+
+
+def huggingface_loader(model_name):
     if shared.model_type == 'chatglm':
         LoaderClass = AutoModel
     elif shared.model_type == 'HF_seq2seq':
@@ -81,37 +133,14 @@ def load_model(model_name):
         LoaderClass = AutoModelForCausalLM
 
     # Load the model in simple 16-bit mode by default
-    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
+    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
+        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
         if torch.has_mps:
             device = torch.device('mps')
             model = model.to(device)
         else:
             model = model.cuda()
 
-    # FlexGen
-    elif shared.args.flexgen:
-        # Initialize environment
-        env = ExecutionEnv.create(shared.args.disk_cache_dir)
-
-        # Offloading policy
-        policy = Policy(1, 1,
-                        shared.args.percent[0], shared.args.percent[1],
-                        shared.args.percent[2], shared.args.percent[3],
-                        shared.args.percent[4], shared.args.percent[5],
-                        overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
-                        cpu_cache_compute=False, attn_sparsity=1.0,
-                        compress_weight=shared.args.compress_weight,
-                        comp_weight_config=CompressionConfig(
-                            num_bits=4, group_size=64,
-                            group_dim=0, symmetric=False),
-                        compress_cache=False,
-                        comp_cache_config=CompressionConfig(
-                            num_bits=4, group_size=64,
-                            group_dim=2, symmetric=False))
-
-        model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
-
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
         model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
@@ -119,50 +148,11 @@ def load_model(model_name):
         model.module.eval()  # Inference
         logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
 
-    # RMKV model (not on HuggingFace)
-    elif shared.model_type == 'rwkv':
-        from modules.RWKV import RWKVModel, RWKVTokenizer
-
-        model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
-        tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
-
-        return model, tokenizer
-
-    # llamacpp model
-    elif shared.model_type == 'llamacpp':
-        from modules.llamacpp_model import LlamaCppModel
-
-        path = Path(f'{shared.args.model_dir}/{model_name}')
-        if path.is_file():
-            model_file = path
-        else:
-            model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
-
-        logging.info(f"llama.cpp weights detected: {model_file}\n")
-        model, tokenizer = LlamaCppModel.from_pretrained(model_file)
-        return model, tokenizer
-
-    # Quantized model
-    elif shared.args.wbits > 0:
-
-        # Monkey patch
-        if shared.args.monkey_patch:
-            logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
-            from modules.monkey_patch_gptq_lora import load_model_llama
-
-            model, _ = load_model_llama(model_name)
-
-        # No monkey patch
-        else:
-            from modules.GPTQ_loader import load_quantized
-
-            model = load_quantized(model_name)
-
     # Custom
     else:
         params = {
             "low_cpu_mem_usage": True,
-            "trust_remote_code": trust_remote_code
+            "trust_remote_code": shared.args.trust_remote_code
         }
 
         if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
@@ -188,9 +178,9 @@ def load_model(model_name):
             checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
 
         if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
-            config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=trust_remote_code)
+            config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=shared.args.trust_remote_code)
             with init_empty_weights():
-                model = LoaderClass.from_config(config, trust_remote_code=trust_remote_code)
+                model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)
 
             model.tie_weights()
             params['device_map'] = infer_auto_device_map(
@@ -202,44 +192,77 @@ def load_model(model_name):
 
         model = LoaderClass.from_pretrained(checkpoint, **params)
 
-    # Hijack attention with xformers
-    if any((shared.args.xformers, shared.args.sdp_attention)):
-        llama_attn_hijack.hijack_llama_attention()
+    return model
 
-    # Loading the tokenizer
-    if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
-        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
-    elif type(model) is transformers.LlamaForCausalLM:
-        tokenizer = None
-        # Try to load an universal LLaMA tokenizer
-        if shared.model_type not in ['llava', 'oasst']:
-            for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
-                if p.exists():
-                    logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
-                    tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
-                    break
+def flexgen_loader(model_name):
+    from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
 
-        # Otherwise, load it from the model folder and hope that these
-        # are not outdated tokenizer files.
-        if tokenizer is None:
-            tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
-            try:
-                tokenizer.eos_token_id = 2
-                tokenizer.bos_token_id = 1
-                tokenizer.pad_token_id = 0
-            except:
-                pass
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
+    # Initialize environment
+    env = ExecutionEnv.create(shared.args.disk_cache_dir)
 
-    logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
+    # Offloading policy
+    policy = Policy(1, 1,
+                    shared.args.percent[0], shared.args.percent[1],
+                    shared.args.percent[2], shared.args.percent[3],
+                    shared.args.percent[4], shared.args.percent[5],
+                    overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
+                    cpu_cache_compute=False, attn_sparsity=1.0,
+                    compress_weight=shared.args.compress_weight,
+                    comp_weight_config=CompressionConfig(
+                        num_bits=4, group_size=64,
+                        group_dim=0, symmetric=False),
+                    compress_cache=False,
+                    comp_cache_config=CompressionConfig(
+                        num_bits=4, group_size=64,
+                        group_dim=2, symmetric=False))
+
+    model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
+    return model
+
+
+def RWKV_loader(model_name):
+    from modules.RWKV import RWKVModel, RWKVTokenizer
+
+    model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+    tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
 
     return model, tokenizer
 
 
+def llamacpp_loader(model_name):
+    from modules.llamacpp_model import LlamaCppModel
+
+    path = Path(f'{shared.args.model_dir}/{model_name}')
+    if path.is_file():
+        model_file = path
+    else:
+        model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
+
+    logging.info(f"llama.cpp weights detected: {model_file}\n")
+    model, tokenizer = LlamaCppModel.from_pretrained(model_file)
+    return model, tokenizer
+
+
+def GPTQ_loader(model_name):
+
+    # Monkey patch
+    if shared.args.monkey_patch:
+        logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
+        from modules.monkey_patch_gptq_lora import load_model_llama
+
+        model, _ = load_model_llama(model_name)
+
+    # No monkey patch
+    else:
+        from modules.GPTQ_loader import load_quantized
+
+        model = load_quantized(model_name)
+
+    return model
+
+
 def get_max_memory_dict():
     max_memory = {}
-
     if shared.args.gpu_memory:
         memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
         for i in range(len(memory_map)):
diff --git a/server.py b/server.py
index b3b1219c..af096865 100644
--- a/server.py
+++ b/server.py
@@ -456,8 +456,8 @@ def create_settings_menus(default_preset):
                     shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
                     shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
                 with gr.Column():
-                    shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
                     shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
+                    shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
                     shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
 
                     shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
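For reference, the updated `docs/DeepSpeed.md` steps above amount to an install-and-launch sequence along the following lines. This is only a sketch: the model name is a placeholder, and extra web UI flags such as `--chat` depend on your setup.

```
conda install -c conda-forge mpi4py mpich
pip install -U deepspeed
deepspeed --num_gpus=1 server.py --deepspeed --chat --model <your-model>
```

`--num_gpus` can be raised on multi-GPU machines, and the `--deepspeed` flag is what routes `modules/models.py` into the DeepSpeed ZeRO-3 branch shown in the diff.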