Refactor models.py (#2113)

Authored by oobabooga on 2023-05-16 19:52:22 -03:00 (committed by GitHub)
parent 5cd6dd4287
commit 7584d46c29
3 changed files with 126 additions and 102 deletions


@@ -9,7 +9,8 @@ As far as I know, DeepSpeed is only available for Linux at the moment.
1. Install DeepSpeed:
```
-pip install deepspeed
+conda install -c conda-forge mpi4py mpich
+pip install -U deepspeed
```
2. Start the web UI replacing `python` with `deepspeed --num_gpus=1` and adding the `--deepspeed` flag. Example:
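For reference, the resulting launch command from step 2 looks roughly like this (the model name and the `--chat` flag are illustrative examples, not part of this change):

```
deepspeed --num_gpus=1 server.py --deepspeed --chat --model gpt-j-6B
```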


@@ -20,9 +20,6 @@ from modules import llama_attn_hijack
transformers.logging.set_verbosity_error()

-if shared.args.flexgen:
-    from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
-
local_rank = None
if shared.args.deepspeed:
    import deepspeed
@@ -40,6 +37,8 @@ if shared.args.deepspeed:
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration


+# Some models require special treatment in various parts of the code.
+# This function detects those models
def find_model_type(model_name):
    model_name_lower = model_name.lower()
    if 'rwkv-' in model_name_lower:
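The added comment is the key point: detection is purely name-based, so the folder name under the models directory decides how a model is treated. A minimal illustration, with a hypothetical model name:

```python
# A folder named like an RWKV checkpoint matches the 'rwkv-' check above,
# so load_model() below will route it to the RWKV loader.
shared.model_type = find_model_type('rwkv-4-pile-169m')  # -> 'rwkv'
```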
@@ -72,7 +71,60 @@ def load_model(model_name):
    t0 = time.time()

    shared.model_type = find_model_type(model_name)
-    trust_remote_code = shared.args.trust_remote_code
+    if shared.args.wbits > 0:
+        load_func = GPTQ_loader
+    elif shared.model_type == 'llamacpp':
+        load_func = llamacpp_loader
+    elif shared.model_type == 'rwkv':
+        load_func = RWKV_loader
+    elif shared.args.flexgen:
+        load_func = flexgen_loader
+    else:
+        load_func = huggingface_loader
+
+    output = load_func(model_name)
+    if type(output) is tuple:
+        model, tokenizer = output
+    else:
+        model = output
+        tokenizer = load_tokenizer(model_name, model)
+
+    # Hijack attention with xformers
+    if any((shared.args.xformers, shared.args.sdp_attention)):
+        llama_attn_hijack.hijack_llama_attention()
+
+    logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
+    return model, tokenizer
+
+
+def load_tokenizer(model_name, model):
+    if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+    elif type(model) is transformers.LlamaForCausalLM:
+        # Try to load an universal LLaMA tokenizer
+        if shared.model_type not in ['llava', 'oasst']:
+            for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
+                if p.exists():
+                    logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
+                    tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
+                    return tokenizer
+
+        # Otherwise, load it from the model folder and hope that these
+        # are not outdated tokenizer files.
+        tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
+        try:
+            tokenizer.eos_token_id = 2
+            tokenizer.bos_token_id = 1
+            tokenizer.pad_token_id = 0
+        except:
+            pass
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=shared.args.trust_remote_code)
+
+    return tokenizer
+
+
+def huggingface_loader(model_name):
    if shared.model_type == 'chatglm':
        LoaderClass = AutoModel
    elif shared.model_type == 'HF_seq2seq':
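With this hunk, `load_model` becomes a thin dispatcher: it picks one of the loader functions defined further down, calls it, and keeps only the cross-cutting steps (tokenizer loading, the optional xformers/SDP attention hijack, and timing). A minimal usage sketch from a caller's perspective — the model name is hypothetical, and the `shared` assignments are an assumption about how the server script wires things up:

```python
from modules import shared
from modules.models import load_model

shared.model_name = 'llama-7b'  # hypothetical folder under --model-dir
shared.model, shared.tokenizer = load_model(shared.model_name)
```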
@@ -81,37 +133,14 @@ def load_model(model_name):
        LoaderClass = AutoModelForCausalLM

    # Load the model in simple 16-bit mode by default
-    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
+    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
+        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
        if torch.has_mps:
            device = torch.device('mps')
            model = model.to(device)
        else:
            model = model.cuda()

-    # FlexGen
-    elif shared.args.flexgen:
-        # Initialize environment
-        env = ExecutionEnv.create(shared.args.disk_cache_dir)
-
-        # Offloading policy
-        policy = Policy(1, 1,
-                        shared.args.percent[0], shared.args.percent[1],
-                        shared.args.percent[2], shared.args.percent[3],
-                        shared.args.percent[4], shared.args.percent[5],
-                        overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
-                        cpu_cache_compute=False, attn_sparsity=1.0,
-                        compress_weight=shared.args.compress_weight,
-                        comp_weight_config=CompressionConfig(
-                            num_bits=4, group_size=64,
-                            group_dim=0, symmetric=False),
-                        compress_cache=False,
-                        comp_cache_config=CompressionConfig(
-                            num_bits=4, group_size=64,
-                            group_dim=2, symmetric=False))
-
-        model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
-
    # DeepSpeed ZeRO-3
    elif shared.args.deepspeed:
        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
@@ -119,50 +148,11 @@ def load_model(model_name):
        model.module.eval() # Inference
        logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")

-    # RMKV model (not on HuggingFace)
-    elif shared.model_type == 'rwkv':
-        from modules.RWKV import RWKVModel, RWKVTokenizer
-
-        model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
-        tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
-        return model, tokenizer
-
-    # llamacpp model
-    elif shared.model_type == 'llamacpp':
-        from modules.llamacpp_model import LlamaCppModel
-
-        path = Path(f'{shared.args.model_dir}/{model_name}')
-        if path.is_file():
-            model_file = path
-        else:
-            model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
-
-        logging.info(f"llama.cpp weights detected: {model_file}\n")
-        model, tokenizer = LlamaCppModel.from_pretrained(model_file)
-        return model, tokenizer
-
-    # Quantized model
-    elif shared.args.wbits > 0:
-        # Monkey patch
-        if shared.args.monkey_patch:
-            logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
-            from modules.monkey_patch_gptq_lora import load_model_llama
-            model, _ = load_model_llama(model_name)
-
-        # No monkey patch
-        else:
-            from modules.GPTQ_loader import load_quantized
-            model = load_quantized(model_name)
-
    # Custom
    else:
        params = {
            "low_cpu_mem_usage": True,
-            "trust_remote_code": trust_remote_code
+            "trust_remote_code": shared.args.trust_remote_code
        }

        if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
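The RWKV, llama.cpp and GPTQ branches removed here are not deleted outright; they reappear below as standalone loader functions. One behavior worth noting from the moved llama.cpp code: the model argument may point either at a directory or directly at a single GGML file. A small sketch with hypothetical names:

```python
# Either form resolves to a model_file inside llamacpp_loader():
#   models/llama-7b.ggmlv3.q4_0.bin   <- a single GGML file passed as the model
#   models/llama-7b-ggml/             <- a folder; the first *ggml*.bin match is used
model, tokenizer = llamacpp_loader('llama-7b-ggml')
```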
@@ -188,9 +178,9 @@ def load_model(model_name):
        checkpoint = Path(f'{shared.args.model_dir}/{model_name}')

        if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
-            config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=trust_remote_code)
+            config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=shared.args.trust_remote_code)
            with init_empty_weights():
-                model = LoaderClass.from_config(config, trust_remote_code=trust_remote_code)
+                model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)

            model.tie_weights()
            params['device_map'] = infer_auto_device_map(
@@ -202,44 +192,77 @@ def load_model(model_name):
        model = LoaderClass.from_pretrained(checkpoint, **params)

-    # Hijack attention with xformers
-    if any((shared.args.xformers, shared.args.sdp_attention)):
-        llama_attn_hijack.hijack_llama_attention()
-
-    # Loading the tokenizer
-    if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
-        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
-    elif type(model) is transformers.LlamaForCausalLM:
-        tokenizer = None
-
-        # Try to load an universal LLaMA tokenizer
-        if shared.model_type not in ['llava', 'oasst']:
-            for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
-                if p.exists():
-                    logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
-                    tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
-                    break
-
-        # Otherwise, load it from the model folder and hope that these
-        # are not outdated tokenizer files.
-        if tokenizer is None:
-            tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
-            try:
-                tokenizer.eos_token_id = 2
-                tokenizer.bos_token_id = 1
-                tokenizer.pad_token_id = 0
-            except:
-                pass
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
-
-    logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
-    return model, tokenizer
+    return model
+
+
+def flexgen_loader(model_name):
+    from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
+
+    # Initialize environment
+    env = ExecutionEnv.create(shared.args.disk_cache_dir)
+
+    # Offloading policy
+    policy = Policy(1, 1,
+                    shared.args.percent[0], shared.args.percent[1],
+                    shared.args.percent[2], shared.args.percent[3],
+                    shared.args.percent[4], shared.args.percent[5],
+                    overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
+                    cpu_cache_compute=False, attn_sparsity=1.0,
+                    compress_weight=shared.args.compress_weight,
+                    comp_weight_config=CompressionConfig(
+                        num_bits=4, group_size=64,
+                        group_dim=0, symmetric=False),
+                    compress_cache=False,
+                    comp_cache_config=CompressionConfig(
+                        num_bits=4, group_size=64,
+                        group_dim=2, symmetric=False))
+
+    model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
+    return model
+
+
+def RWKV_loader(model_name):
+    from modules.RWKV import RWKVModel, RWKVTokenizer
+
+    model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+    tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
+    return model, tokenizer
+
+
+def llamacpp_loader(model_name):
+    from modules.llamacpp_model import LlamaCppModel
+
+    path = Path(f'{shared.args.model_dir}/{model_name}')
+    if path.is_file():
+        model_file = path
+    else:
+        model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
+
+    logging.info(f"llama.cpp weights detected: {model_file}\n")
+    model, tokenizer = LlamaCppModel.from_pretrained(model_file)
+    return model, tokenizer
+
+
+def GPTQ_loader(model_name):
+    # Monkey patch
+    if shared.args.monkey_patch:
+        logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
+        from modules.monkey_patch_gptq_lora import load_model_llama
+        model, _ = load_model_llama(model_name)
+
+    # No monkey patch
+    else:
+        from modules.GPTQ_loader import load_quantized
+        model = load_quantized(model_name)
+
+    return model


def get_max_memory_dict():
    max_memory = {}
    if shared.args.gpu_memory:
        memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
        for i in range(len(memory_map)):
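The models.py portion of the diff ends inside `get_max_memory_dict`, which turns the `--gpu-memory`/`--cpu-memory` command-line values into the `max_memory` dictionary that accelerate's device-map machinery expects. The full parsing is not shown here; the target format is the standard accelerate one, roughly (device indices and sizes are only examples):

```python
# One entry per GPU index plus an optional 'cpu' entry; values are strings with units.
max_memory = {0: '10GiB', 1: '10GiB', 'cpu': '30GiB'}
# A dict like this ends up in params['max_memory'] and feeds infer_auto_device_map().
```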


@@ -456,8 +456,8 @@ def create_settings_menus(default_preset):
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
with gr.Column():
-    shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
    shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
+    shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
    shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
    shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')