mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Refactor models.py (#2113)
This commit is contained in:
parent
5cd6dd4287
commit
7584d46c29
@ -9,7 +9,8 @@ As far as I know, DeepSpeed is only available for Linux at the moment.
|
|||||||
1. Install DeepSpeed:
|
1. Install DeepSpeed:
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install deepspeed
|
conda install -c conda-forge mpi4py mpich
|
||||||
|
pip install -U deepspeed
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Start the web UI replacing `python` with `deepspeed --num_gpus=1` and adding the `--deepspeed` flag. Example:
|
2. Start the web UI replacing `python` with `deepspeed --num_gpus=1` and adding the `--deepspeed` flag. Example:
|
||||||
|
@ -20,9 +20,6 @@ from modules import llama_attn_hijack
|
|||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
if shared.args.flexgen:
|
|
||||||
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
|
||||||
|
|
||||||
local_rank = None
|
local_rank = None
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
import deepspeed
|
import deepspeed
|
||||||
@ -40,6 +37,8 @@ if shared.args.deepspeed:
|
|||||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||||
|
|
||||||
|
|
||||||
|
# Some models require special treatment in various parts of the code.
|
||||||
|
# This function detects those models
|
||||||
def find_model_type(model_name):
|
def find_model_type(model_name):
|
||||||
model_name_lower = model_name.lower()
|
model_name_lower = model_name.lower()
|
||||||
if 'rwkv-' in model_name_lower:
|
if 'rwkv-' in model_name_lower:
|
||||||
@ -72,7 +71,60 @@ def load_model(model_name):
|
|||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
shared.model_type = find_model_type(model_name)
|
shared.model_type = find_model_type(model_name)
|
||||||
trust_remote_code = shared.args.trust_remote_code
|
if shared.args.wbits > 0:
|
||||||
|
load_func = GPTQ_loader
|
||||||
|
elif shared.model_type == 'llamacpp':
|
||||||
|
load_func = llamacpp_loader
|
||||||
|
elif shared.model_type == 'rwkv':
|
||||||
|
load_func = RWKV_loader
|
||||||
|
elif shared.args.flexgen:
|
||||||
|
load_func = flexgen_loader
|
||||||
|
else:
|
||||||
|
load_func = huggingface_loader
|
||||||
|
|
||||||
|
output = load_func(model_name)
|
||||||
|
if type(output) is tuple:
|
||||||
|
model, tokenizer = output
|
||||||
|
else:
|
||||||
|
model = output
|
||||||
|
tokenizer = load_tokenizer(model_name, model)
|
||||||
|
|
||||||
|
# Hijack attention with xformers
|
||||||
|
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||||
|
llama_attn_hijack.hijack_llama_attention()
|
||||||
|
|
||||||
|
logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def load_tokenizer(model_name, model):
|
||||||
|
if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||||
|
elif type(model) is transformers.LlamaForCausalLM:
|
||||||
|
# Try to load an universal LLaMA tokenizer
|
||||||
|
if shared.model_type not in ['llava', 'oasst']:
|
||||||
|
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||||
|
if p.exists():
|
||||||
|
logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
# Otherwise, load it from the model folder and hope that these
|
||||||
|
# are not outdated tokenizer files.
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
|
||||||
|
try:
|
||||||
|
tokenizer.eos_token_id = 2
|
||||||
|
tokenizer.bos_token_id = 1
|
||||||
|
tokenizer.pad_token_id = 0
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=shared.args.trust_remote_code)
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def huggingface_loader(model_name):
|
||||||
if shared.model_type == 'chatglm':
|
if shared.model_type == 'chatglm':
|
||||||
LoaderClass = AutoModel
|
LoaderClass = AutoModel
|
||||||
elif shared.model_type == 'HF_seq2seq':
|
elif shared.model_type == 'HF_seq2seq':
|
||||||
@ -81,37 +133,14 @@ def load_model(model_name):
|
|||||||
LoaderClass = AutoModelForCausalLM
|
LoaderClass = AutoModelForCausalLM
|
||||||
|
|
||||||
# Load the model in simple 16-bit mode by default
|
# Load the model in simple 16-bit mode by default
|
||||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
|
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
|
||||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
|
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
|
||||||
if torch.has_mps:
|
if torch.has_mps:
|
||||||
device = torch.device('mps')
|
device = torch.device('mps')
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
else:
|
else:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
|
||||||
# FlexGen
|
|
||||||
elif shared.args.flexgen:
|
|
||||||
# Initialize environment
|
|
||||||
env = ExecutionEnv.create(shared.args.disk_cache_dir)
|
|
||||||
|
|
||||||
# Offloading policy
|
|
||||||
policy = Policy(1, 1,
|
|
||||||
shared.args.percent[0], shared.args.percent[1],
|
|
||||||
shared.args.percent[2], shared.args.percent[3],
|
|
||||||
shared.args.percent[4], shared.args.percent[5],
|
|
||||||
overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
|
|
||||||
cpu_cache_compute=False, attn_sparsity=1.0,
|
|
||||||
compress_weight=shared.args.compress_weight,
|
|
||||||
comp_weight_config=CompressionConfig(
|
|
||||||
num_bits=4, group_size=64,
|
|
||||||
group_dim=0, symmetric=False),
|
|
||||||
compress_cache=False,
|
|
||||||
comp_cache_config=CompressionConfig(
|
|
||||||
num_bits=4, group_size=64,
|
|
||||||
group_dim=2, symmetric=False))
|
|
||||||
|
|
||||||
model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
|
|
||||||
|
|
||||||
# DeepSpeed ZeRO-3
|
# DeepSpeed ZeRO-3
|
||||||
elif shared.args.deepspeed:
|
elif shared.args.deepspeed:
|
||||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||||
@ -119,50 +148,11 @@ def load_model(model_name):
|
|||||||
model.module.eval() # Inference
|
model.module.eval() # Inference
|
||||||
logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||||
|
|
||||||
# RMKV model (not on HuggingFace)
|
|
||||||
elif shared.model_type == 'rwkv':
|
|
||||||
from modules.RWKV import RWKVModel, RWKVTokenizer
|
|
||||||
|
|
||||||
model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
|
|
||||||
tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
|
|
||||||
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
# llamacpp model
|
|
||||||
elif shared.model_type == 'llamacpp':
|
|
||||||
from modules.llamacpp_model import LlamaCppModel
|
|
||||||
|
|
||||||
path = Path(f'{shared.args.model_dir}/{model_name}')
|
|
||||||
if path.is_file():
|
|
||||||
model_file = path
|
|
||||||
else:
|
|
||||||
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
|
|
||||||
|
|
||||||
logging.info(f"llama.cpp weights detected: {model_file}\n")
|
|
||||||
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
# Quantized model
|
|
||||||
elif shared.args.wbits > 0:
|
|
||||||
|
|
||||||
# Monkey patch
|
|
||||||
if shared.args.monkey_patch:
|
|
||||||
logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
|
|
||||||
from modules.monkey_patch_gptq_lora import load_model_llama
|
|
||||||
|
|
||||||
model, _ = load_model_llama(model_name)
|
|
||||||
|
|
||||||
# No monkey patch
|
|
||||||
else:
|
|
||||||
from modules.GPTQ_loader import load_quantized
|
|
||||||
|
|
||||||
model = load_quantized(model_name)
|
|
||||||
|
|
||||||
# Custom
|
# Custom
|
||||||
else:
|
else:
|
||||||
params = {
|
params = {
|
||||||
"low_cpu_mem_usage": True,
|
"low_cpu_mem_usage": True,
|
||||||
"trust_remote_code": trust_remote_code
|
"trust_remote_code": shared.args.trust_remote_code
|
||||||
}
|
}
|
||||||
|
|
||||||
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
|
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
|
||||||
@ -188,9 +178,9 @@ def load_model(model_name):
|
|||||||
|
|
||||||
checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
|
checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
|
||||||
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
|
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
|
||||||
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=trust_remote_code)
|
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=shared.args.trust_remote_code)
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = LoaderClass.from_config(config, trust_remote_code=trust_remote_code)
|
model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)
|
||||||
|
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
params['device_map'] = infer_auto_device_map(
|
params['device_map'] = infer_auto_device_map(
|
||||||
@ -202,44 +192,77 @@ def load_model(model_name):
|
|||||||
|
|
||||||
model = LoaderClass.from_pretrained(checkpoint, **params)
|
model = LoaderClass.from_pretrained(checkpoint, **params)
|
||||||
|
|
||||||
# Hijack attention with xformers
|
return model
|
||||||
if any((shared.args.xformers, shared.args.sdp_attention)):
|
|
||||||
llama_attn_hijack.hijack_llama_attention()
|
|
||||||
|
|
||||||
# Loading the tokenizer
|
|
||||||
if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
|
||||||
elif type(model) is transformers.LlamaForCausalLM:
|
|
||||||
tokenizer = None
|
|
||||||
|
|
||||||
# Try to load an universal LLaMA tokenizer
|
def flexgen_loader(model_name):
|
||||||
if shared.model_type not in ['llava', 'oasst']:
|
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
||||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
|
||||||
if p.exists():
|
|
||||||
logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
|
||||||
break
|
|
||||||
|
|
||||||
# Otherwise, load it from the model folder and hope that these
|
# Initialize environment
|
||||||
# are not outdated tokenizer files.
|
env = ExecutionEnv.create(shared.args.disk_cache_dir)
|
||||||
if tokenizer is None:
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
|
|
||||||
try:
|
|
||||||
tokenizer.eos_token_id = 2
|
|
||||||
tokenizer.bos_token_id = 1
|
|
||||||
tokenizer.pad_token_id = 0
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
|
|
||||||
|
|
||||||
logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
|
# Offloading policy
|
||||||
|
policy = Policy(1, 1,
|
||||||
|
shared.args.percent[0], shared.args.percent[1],
|
||||||
|
shared.args.percent[2], shared.args.percent[3],
|
||||||
|
shared.args.percent[4], shared.args.percent[5],
|
||||||
|
overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
|
||||||
|
cpu_cache_compute=False, attn_sparsity=1.0,
|
||||||
|
compress_weight=shared.args.compress_weight,
|
||||||
|
comp_weight_config=CompressionConfig(
|
||||||
|
num_bits=4, group_size=64,
|
||||||
|
group_dim=0, symmetric=False),
|
||||||
|
compress_cache=False,
|
||||||
|
comp_cache_config=CompressionConfig(
|
||||||
|
num_bits=4, group_size=64,
|
||||||
|
group_dim=2, symmetric=False))
|
||||||
|
|
||||||
|
model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def RWKV_loader(model_name):
|
||||||
|
from modules.RWKV import RWKVModel, RWKVTokenizer
|
||||||
|
|
||||||
|
model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
|
||||||
|
tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def llamacpp_loader(model_name):
|
||||||
|
from modules.llamacpp_model import LlamaCppModel
|
||||||
|
|
||||||
|
path = Path(f'{shared.args.model_dir}/{model_name}')
|
||||||
|
if path.is_file():
|
||||||
|
model_file = path
|
||||||
|
else:
|
||||||
|
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
|
||||||
|
|
||||||
|
logging.info(f"llama.cpp weights detected: {model_file}\n")
|
||||||
|
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def GPTQ_loader(model_name):
|
||||||
|
|
||||||
|
# Monkey patch
|
||||||
|
if shared.args.monkey_patch:
|
||||||
|
logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
|
||||||
|
from modules.monkey_patch_gptq_lora import load_model_llama
|
||||||
|
|
||||||
|
model, _ = load_model_llama(model_name)
|
||||||
|
|
||||||
|
# No monkey patch
|
||||||
|
else:
|
||||||
|
from modules.GPTQ_loader import load_quantized
|
||||||
|
|
||||||
|
model = load_quantized(model_name)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def get_max_memory_dict():
|
def get_max_memory_dict():
|
||||||
max_memory = {}
|
max_memory = {}
|
||||||
|
|
||||||
if shared.args.gpu_memory:
|
if shared.args.gpu_memory:
|
||||||
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
||||||
for i in range(len(memory_map)):
|
for i in range(len(memory_map)):
|
||||||
|
@ -456,8 +456,8 @@ def create_settings_menus(default_preset):
|
|||||||
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||||
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
|
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
|
||||||
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
|
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
|
||||||
|
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||||
|
|
||||||
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
||||||
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
|
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
|
||||||
|
Loading…
Reference in New Issue
Block a user