mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-23 10:09:20 +01:00
Refactor models.py (#2113)
This commit is contained in:
parent
5cd6dd4287
commit
7584d46c29
@ -9,7 +9,8 @@ As far as I know, DeepSpeed is only available for Linux at the moment.
|
||||
1. Install DeepSpeed:
|
||||
|
||||
```
|
||||
pip install deepspeed
|
||||
conda install -c conda-forge mpi4py mpich
|
||||
pip install -U deepspeed
|
||||
```
|
||||
|
||||
2. Start the web UI replacing `python` with `deepspeed --num_gpus=1` and adding the `--deepspeed` flag. Example:
|
||||
|
@ -20,9 +20,6 @@ from modules import llama_attn_hijack
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
if shared.args.flexgen:
|
||||
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
||||
|
||||
local_rank = None
|
||||
if shared.args.deepspeed:
|
||||
import deepspeed
|
||||
@ -40,6 +37,8 @@ if shared.args.deepspeed:
|
||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
|
||||
|
||||
# Some models require special treatment in various parts of the code.
|
||||
# This function detects those models
|
||||
def find_model_type(model_name):
|
||||
model_name_lower = model_name.lower()
|
||||
if 'rwkv-' in model_name_lower:
|
||||
@ -72,7 +71,60 @@ def load_model(model_name):
|
||||
t0 = time.time()
|
||||
|
||||
shared.model_type = find_model_type(model_name)
|
||||
trust_remote_code = shared.args.trust_remote_code
|
||||
if shared.args.wbits > 0:
|
||||
load_func = GPTQ_loader
|
||||
elif shared.model_type == 'llamacpp':
|
||||
load_func = llamacpp_loader
|
||||
elif shared.model_type == 'rwkv':
|
||||
load_func = RWKV_loader
|
||||
elif shared.args.flexgen:
|
||||
load_func = flexgen_loader
|
||||
else:
|
||||
load_func = huggingface_loader
|
||||
|
||||
output = load_func(model_name)
|
||||
if type(output) is tuple:
|
||||
model, tokenizer = output
|
||||
else:
|
||||
model = output
|
||||
tokenizer = load_tokenizer(model_name, model)
|
||||
|
||||
# Hijack attention with xformers
|
||||
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||
llama_attn_hijack.hijack_llama_attention()
|
||||
|
||||
logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def load_tokenizer(model_name, model):
|
||||
if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM:
|
||||
# Try to load an universal LLaMA tokenizer
|
||||
if shared.model_type not in ['llava', 'oasst']:
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
return tokenizer
|
||||
|
||||
# Otherwise, load it from the model folder and hope that these
|
||||
# are not outdated tokenizer files.
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
|
||||
try:
|
||||
tokenizer.eos_token_id = 2
|
||||
tokenizer.bos_token_id = 1
|
||||
tokenizer.pad_token_id = 0
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=shared.args.trust_remote_code)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def huggingface_loader(model_name):
|
||||
if shared.model_type == 'chatglm':
|
||||
LoaderClass = AutoModel
|
||||
elif shared.model_type == 'HF_seq2seq':
|
||||
@ -81,37 +133,14 @@ def load_model(model_name):
|
||||
LoaderClass = AutoModelForCausalLM
|
||||
|
||||
# Load the model in simple 16-bit mode by default
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
|
||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
|
||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
|
||||
if torch.has_mps:
|
||||
device = torch.device('mps')
|
||||
model = model.to(device)
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
||||
# FlexGen
|
||||
elif shared.args.flexgen:
|
||||
# Initialize environment
|
||||
env = ExecutionEnv.create(shared.args.disk_cache_dir)
|
||||
|
||||
# Offloading policy
|
||||
policy = Policy(1, 1,
|
||||
shared.args.percent[0], shared.args.percent[1],
|
||||
shared.args.percent[2], shared.args.percent[3],
|
||||
shared.args.percent[4], shared.args.percent[5],
|
||||
overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
|
||||
cpu_cache_compute=False, attn_sparsity=1.0,
|
||||
compress_weight=shared.args.compress_weight,
|
||||
comp_weight_config=CompressionConfig(
|
||||
num_bits=4, group_size=64,
|
||||
group_dim=0, symmetric=False),
|
||||
compress_cache=False,
|
||||
comp_cache_config=CompressionConfig(
|
||||
num_bits=4, group_size=64,
|
||||
group_dim=2, symmetric=False))
|
||||
|
||||
model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
|
||||
|
||||
# DeepSpeed ZeRO-3
|
||||
elif shared.args.deepspeed:
|
||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||
@ -119,50 +148,11 @@ def load_model(model_name):
|
||||
model.module.eval() # Inference
|
||||
logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||
|
||||
# RMKV model (not on HuggingFace)
|
||||
elif shared.model_type == 'rwkv':
|
||||
from modules.RWKV import RWKVModel, RWKVTokenizer
|
||||
|
||||
model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
|
||||
tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
# llamacpp model
|
||||
elif shared.model_type == 'llamacpp':
|
||||
from modules.llamacpp_model import LlamaCppModel
|
||||
|
||||
path = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if path.is_file():
|
||||
model_file = path
|
||||
else:
|
||||
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
|
||||
|
||||
logging.info(f"llama.cpp weights detected: {model_file}\n")
|
||||
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
||||
return model, tokenizer
|
||||
|
||||
# Quantized model
|
||||
elif shared.args.wbits > 0:
|
||||
|
||||
# Monkey patch
|
||||
if shared.args.monkey_patch:
|
||||
logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
|
||||
from modules.monkey_patch_gptq_lora import load_model_llama
|
||||
|
||||
model, _ = load_model_llama(model_name)
|
||||
|
||||
# No monkey patch
|
||||
else:
|
||||
from modules.GPTQ_loader import load_quantized
|
||||
|
||||
model = load_quantized(model_name)
|
||||
|
||||
# Custom
|
||||
else:
|
||||
params = {
|
||||
"low_cpu_mem_usage": True,
|
||||
"trust_remote_code": trust_remote_code
|
||||
"trust_remote_code": shared.args.trust_remote_code
|
||||
}
|
||||
|
||||
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
|
||||
@ -188,9 +178,9 @@ def load_model(model_name):
|
||||
|
||||
checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
|
||||
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=trust_remote_code)
|
||||
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=shared.args.trust_remote_code)
|
||||
with init_empty_weights():
|
||||
model = LoaderClass.from_config(config, trust_remote_code=trust_remote_code)
|
||||
model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)
|
||||
|
||||
model.tie_weights()
|
||||
params['device_map'] = infer_auto_device_map(
|
||||
@ -202,44 +192,77 @@ def load_model(model_name):
|
||||
|
||||
model = LoaderClass.from_pretrained(checkpoint, **params)
|
||||
|
||||
# Hijack attention with xformers
|
||||
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||
llama_attn_hijack.hijack_llama_attention()
|
||||
return model
|
||||
|
||||
# Loading the tokenizer
|
||||
if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM:
|
||||
tokenizer = None
|
||||
|
||||
# Try to load an universal LLaMA tokenizer
|
||||
if shared.model_type not in ['llava', 'oasst']:
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
break
|
||||
def flexgen_loader(model_name):
|
||||
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
||||
|
||||
# Otherwise, load it from the model folder and hope that these
|
||||
# are not outdated tokenizer files.
|
||||
if tokenizer is None:
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
|
||||
try:
|
||||
tokenizer.eos_token_id = 2
|
||||
tokenizer.bos_token_id = 1
|
||||
tokenizer.pad_token_id = 0
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
|
||||
# Initialize environment
|
||||
env = ExecutionEnv.create(shared.args.disk_cache_dir)
|
||||
|
||||
logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
|
||||
# Offloading policy
|
||||
policy = Policy(1, 1,
|
||||
shared.args.percent[0], shared.args.percent[1],
|
||||
shared.args.percent[2], shared.args.percent[3],
|
||||
shared.args.percent[4], shared.args.percent[5],
|
||||
overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
|
||||
cpu_cache_compute=False, attn_sparsity=1.0,
|
||||
compress_weight=shared.args.compress_weight,
|
||||
comp_weight_config=CompressionConfig(
|
||||
num_bits=4, group_size=64,
|
||||
group_dim=0, symmetric=False),
|
||||
compress_cache=False,
|
||||
comp_cache_config=CompressionConfig(
|
||||
num_bits=4, group_size=64,
|
||||
group_dim=2, symmetric=False))
|
||||
|
||||
model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
|
||||
return model
|
||||
|
||||
|
||||
def RWKV_loader(model_name):
|
||||
from modules.RWKV import RWKVModel, RWKVTokenizer
|
||||
|
||||
model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
|
||||
tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def llamacpp_loader(model_name):
|
||||
from modules.llamacpp_model import LlamaCppModel
|
||||
|
||||
path = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if path.is_file():
|
||||
model_file = path
|
||||
else:
|
||||
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
|
||||
|
||||
logging.info(f"llama.cpp weights detected: {model_file}\n")
|
||||
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def GPTQ_loader(model_name):
|
||||
|
||||
# Monkey patch
|
||||
if shared.args.monkey_patch:
|
||||
logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
|
||||
from modules.monkey_patch_gptq_lora import load_model_llama
|
||||
|
||||
model, _ = load_model_llama(model_name)
|
||||
|
||||
# No monkey patch
|
||||
else:
|
||||
from modules.GPTQ_loader import load_quantized
|
||||
|
||||
model = load_quantized(model_name)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_max_memory_dict():
|
||||
max_memory = {}
|
||||
|
||||
if shared.args.gpu_memory:
|
||||
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
||||
for i in range(len(memory_map)):
|
||||
|
@ -456,8 +456,8 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
|
||||
with gr.Column():
|
||||
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
|
||||
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||
|
||||
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
||||
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
|
||||
|
Loading…
Reference in New Issue
Block a user