2023-04-08 02:36:04 +02:00
import gc
2023-02-23 17:28:30 +01:00
import json
2023-05-04 02:43:17 +02:00
import logging
2023-02-23 17:28:30 +01:00
import os
2023-03-19 23:21:41 +01:00
import re
2023-02-23 17:28:30 +01:00
import time
import zipfile
from pathlib import Path
import numpy as np
import torch
2023-02-23 17:42:23 +01:00
import transformers
2023-03-16 17:34:23 +01:00
from accelerate import infer_auto_device_map , init_empty_weights
2023-04-17 00:15:03 +02:00
from transformers import ( AutoConfig , AutoModel , AutoModelForCausalLM ,
2023-04-26 03:39:04 +02:00
AutoModelForSeq2SeqLM , AutoTokenizer ,
BitsAndBytesConfig , LlamaTokenizer )
2023-02-23 18:41:42 +01:00
import modules . shared as shared
2023-04-10 04:08:40 +02:00
from modules import llama_attn_hijack
2023-02-23 17:28:30 +01:00
2023-02-23 17:42:23 +01:00
transformers.logging.set_verbosity_error()

# Optional back ends: their imports are only attempted when the matching flag is set,
# so the packages are not required otherwise.
if shared.args.flexgen:
    from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy

local_rank = None

if shared.args.deepspeed:
    import deepspeed
    from transformers.deepspeed import (HfDeepSpeedConfig,
                                        is_deepspeed_zero3_enabled)

    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup: an explicit --local-rank wins over the LOCAL_RANK env variable.
    if shared.args.local_rank is not None:
        local_rank = shared.args.local_rank
    else:
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    torch.cuda.set_device(local_rank)
    deepspeed.init_distributed()

    ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
    # Keep this object alive for the Transformers integration
    dschf = HfDeepSpeedConfig(ds_config)
2023-02-23 17:28:30 +01:00
2023-03-13 18:00:38 +01:00
2023-04-22 19:56:48 +02:00
def find_model_type(model_name):
    """Guess which loader family ``model_name`` belongs to.

    The checks run in this order:
    - 'rwkv-' in the lower-cased name -> 'rwkv'
    - any '*ggml*.bin' file inside ``<model_dir>/<model_name>`` -> 'llamacpp'
    - the name itself matching '...ggml....bin' -> 'llamacpp'
    - 'chatglm', 'galactica', 'llava', 'gpt4chan'/'gpt-4chan' substrings
    - otherwise the Hugging Face config decides between 'HF_seq2seq'
      (``is_encoder_decoder``) and 'HF_generic'.

    Returns one of: 'rwkv', 'llamacpp', 'chatglm', 'galactica', 'llava',
    'gpt4chan', 'HF_seq2seq', 'HF_generic'.
    """
    model_name_lower = model_name.lower()
    model_path = Path(f'{shared.args.model_dir}/{model_name}')

    if 'rwkv-' in model_name_lower:
        return 'rwkv'
    # any() stops at the first matching file instead of listing the whole folder
    elif any(model_path.glob('*ggml*.bin')):
        return 'llamacpp'
    # Raw string: '\.' in a normal literal is an invalid escape sequence
    elif re.match(r'.*ggml.*\.bin', model_name_lower):
        return 'llamacpp'
    elif 'chatglm' in model_name_lower:
        return 'chatglm'
    elif 'galactica' in model_name_lower:
        return 'galactica'
    elif 'llava' in model_name_lower:
        return 'llava'
    elif any(k in model_name_lower for k in ['gpt4chan', 'gpt-4chan']):
        return 'gpt4chan'
    else:
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=shared.args.trust_remote_code)
        # Not a "catch all", but fairly accurate
        if config.to_dict().get("is_encoder_decoder", False):
            return 'HF_seq2seq'
        else:
            return 'HF_generic'
2023-04-22 19:56:48 +02:00
2023-02-23 17:28:30 +01:00
def _memory_with_unit(value):
    """Return ``value`` as a memory-size string, appending 'GiB' when it has no unit.

    Surrounding whitespace is stripped. Values already ending in '...ib'
    (case-insensitive, e.g. '3500MiB' or '8GiB') are kept as they are; bare
    amounts such as '8' or 99 become '8GiB' / '99GiB'.
    """
    value = str(value).strip()
    return value if value.lower().endswith('ib') else f'{value}GiB'


def _load_tokenizer(model, model_name, trust_remote_code):
    """Load the tokenizer that goes with ``model``.

    - gpt4chan models use ``<model_dir>/gpt-j-6B/`` when that folder exists.
    - ``LlamaForCausalLM`` models prefer a shared "universal" LLaMA tokenizer
      folder (skipped for llava), falling back to the model's own folder, and
      get their eos/bos/pad ids forced to 2/1/0.
    - Everything else uses ``AutoTokenizer`` on the model folder.
    """
    model_dir = shared.args.model_dir
    if shared.model_type == 'gpt4chan' and Path(f"{model_dir}/gpt-j-6B/").exists():
        return AutoTokenizer.from_pretrained(Path(f"{model_dir}/gpt-j-6B/"))

    if type(model) is transformers.LlamaForCausalLM:
        tokenizer = None
        # Try to load an universal LLaMA tokenizer
        if shared.model_type != 'llava':
            for p in [Path(f"{model_dir}/llama-tokenizer/"), Path(f"{model_dir}/oobabooga_llama-tokenizer/")]:
                if p.exists():
                    logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
                    tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
                    break

        # Otherwise, load it from the model folder and hope that these
        # are not outdated tokenizer files.
        if tokenizer is None:
            tokenizer = LlamaTokenizer.from_pretrained(Path(f"{model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)

        # Best effort: report a failed override instead of silently swallowing
        # every error (including KeyboardInterrupt) as a bare except did.
        try:
            tokenizer.eos_token_id = 2
            tokenizer.bos_token_id = 1
            tokenizer.pad_token_id = 0
        except Exception as e:
            logging.warning(f"Could not override the LLaMA special token ids: {e}")

        return tokenizer

    return AutoTokenizer.from_pretrained(Path(f"{model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)


def load_model(model_name):
    """Load ``<model_dir>/<model_name>`` and return ``(model, tokenizer)``.

    The strategy depends on ``find_model_type`` and the flags in ``shared.args``:
    - plain 16-bit load moved to MPS or CUDA (no special flags);
    - FlexGen;
    - DeepSpeed ZeRO-3;
    - RWKV;
    - llama.cpp;
    - GPTQ (``--wbits``);
    - a configurable ``from_pretrained`` call (CPU, 8-bit, memory limits,
      disk offload).

    Side effects:
    - sets ``shared.model_type``;
    - may set ``shared.args.cpu = True`` when no GPU is detected.
    """
    logging.info(f"Loading {model_name}...")
    t0 = time.time()

    shared.model_type = find_model_type(model_name)
    trust_remote_code = shared.args.trust_remote_code
    if shared.model_type == 'chatglm':
        LoaderClass = AutoModel
    elif shared.model_type == 'HF_seq2seq':
        LoaderClass = AutoModelForSeq2SeqLM
    else:
        LoaderClass = AutoModelForCausalLM

    # Load the model in simple 16-bit mode by default
    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
        if torch.has_mps:
            device = torch.device('mps')
            model = model.to(device)
        else:
            model = model.cuda()

    # FlexGen
    elif shared.args.flexgen:
        # Initialize environment
        env = ExecutionEnv.create(shared.args.disk_cache_dir)

        # Offloading policy
        policy = Policy(1, 1,
                        shared.args.percent[0], shared.args.percent[1],
                        shared.args.percent[2], shared.args.percent[3],
                        shared.args.percent[4], shared.args.percent[5],
                        overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
                        cpu_cache_compute=False, attn_sparsity=1.0,
                        compress_weight=shared.args.compress_weight,
                        comp_weight_config=CompressionConfig(
                            num_bits=4, group_size=64,
                            group_dim=0, symmetric=False),
                        compress_cache=False,
                        comp_cache_config=CompressionConfig(
                            num_bits=4, group_size=64,
                            group_dim=2, symmetric=False))

        model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)

    # DeepSpeed ZeRO-3
    elif shared.args.deepspeed:
        # trust_remote_code is forwarded here too, consistent with the other HF loaders
        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
        model.module.eval()  # Inference
        logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")

    # RWKV model (not on HuggingFace)
    elif shared.model_type == 'rwkv':
        from modules.RWKV import RWKVModel, RWKVTokenizer

        model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
        tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))

        return model, tokenizer

    # llamacpp model
    elif shared.model_type == 'llamacpp':
        from modules.llamacpp_model import LlamaCppModel

        path = Path(f'{shared.args.model_dir}/{model_name}')
        if path.is_file():
            model_file = path
        else:
            model_file = list(path.glob('*ggml*.bin'))[0]

        logging.info(f"llama.cpp weights detected: {model_file}\n")
        model, tokenizer = LlamaCppModel.from_pretrained(model_file)
        return model, tokenizer

    # Quantized model
    elif shared.args.wbits > 0:

        # Monkey patch
        if shared.args.monkey_patch:
            logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
            from modules.monkey_patch_gptq_lora import load_model_llama
            model, _ = load_model_llama(model_name)

        # No monkey patch
        else:
            from modules.GPTQ_loader import load_quantized
            model = load_quantized(model_name)

    # Custom
    else:
        # trust_remote_code is set for every mode; previously it was only added
        # in the GPU branch, so --cpu with --trust-remote-code did not forward it.
        params = {"low_cpu_mem_usage": True, "trust_remote_code": trust_remote_code}
        if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
            logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
            shared.args.cpu = True

        if shared.args.cpu:
            params["torch_dtype"] = torch.float32
        else:
            params["device_map"] = 'auto'
            if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
            elif shared.args.load_in_8bit:
                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
            elif shared.args.bf16:
                params["torch_dtype"] = torch.bfloat16
            else:
                params["torch_dtype"] = torch.float16

            if shared.args.gpu_memory:
                # One entry per GPU index; bare numbers (GPU and CPU alike) are read as GiB
                max_memory = {i: _memory_with_unit(mem) for i, mem in enumerate(shared.args.gpu_memory)}
                max_memory['cpu'] = _memory_with_unit(shared.args.cpu_memory) if shared.args.cpu_memory is not None else '99GiB'
                params['max_memory'] = max_memory
            elif shared.args.auto_devices:
                total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
                suggestion = round((total_mem - 1000) / 1000) * 1000
                if total_mem - suggestion < 800:
                    suggestion -= 1000

                suggestion = int(round(suggestion / 1000))
                logging.warning(f"\033[1;32;1mAuto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
                # _memory_with_unit avoids '20GiBGiB' when --cpu-memory already has a unit
                params['max_memory'] = {0: f'{suggestion}GiB', 'cpu': _memory_with_unit(shared.args.cpu_memory or 99)}

        if shared.args.disk:
            params["offload_folder"] = shared.args.disk_cache_dir

        checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
        if shared.args.load_in_8bit and params.get('max_memory') is not None and params.get('device_map') == 'auto':
            # Compute the device map ourselves, sized for int8 weights, on an empty
            # (weightless) model skeleton.
            config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=trust_remote_code)
            with init_empty_weights():
                model = LoaderClass.from_config(config, trust_remote_code=trust_remote_code)

            model.tie_weights()
            params['device_map'] = infer_auto_device_map(
                model,
                dtype=torch.int8,
                max_memory=params['max_memory'],
                no_split_module_classes=model._no_split_modules
            )

        model = LoaderClass.from_pretrained(checkpoint, **params)

    # Hijack attention with xformers
    if any((shared.args.xformers, shared.args.sdp_attention)):
        llama_attn_hijack.hijack_llama_attention()

    # Loading the tokenizer
    tokenizer = _load_tokenizer(model, model_name, trust_remote_code)

    logging.info(f"Loaded the model in {(time.time() - t0):.2f} seconds.")
    return model, tokenizer
2023-04-07 05:15:45 +02:00
2023-04-08 02:36:04 +02:00
def clear_torch_cache():
    """Run the garbage collector and, unless running on CPU, release cached CUDA memory."""
    gc.collect()
    if shared.args.cpu:
        return
    torch.cuda.empty_cache()


def unload_model():
    """Drop the shared model/tokenizer references, then free cached memory."""
    shared.model = None
    shared.tokenizer = None
    clear_torch_cache()


def reload_model():
    """Unload the current model and load ``shared.model_name`` again."""
    unload_model()
    loaded = load_model(shared.model_name)
    shared.model, shared.tokenizer = loaded


2023-02-23 17:28:30 +01:00
def load_soft_prompt(name):
    """Activate the soft prompt stored in ``softprompts/<name>.zip``, or disable it.

    Passing ``'None'`` sets ``shared.soft_prompt`` to False and clears
    ``shared.soft_prompt_tensor``.

    Otherwise the archive must contain two members:
    - ``meta.json``: every field except ``name`` is logged.
    - ``tensor.npy``: its array is moved to the model's device/dtype and
      reshaped to ``(1, shape[0], shape[1])``, so it is expected to be 2-D.
      The result is stored in ``shared.soft_prompt_tensor``.

    Returns ``name``.
    """
    import io  # only needed to feed in-memory bytes to np.load

    if name == 'None':
        shared.soft_prompt = False
        shared.soft_prompt_tensor = None
        return name

    # Read both members directly from the archive instead of extracting them
    # into the working directory: no leaked file handle, no clobbered or
    # left-over 'meta.json'/'tensor.npy' files if something fails midway.
    with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
        j = json.loads(zf.read('meta.json'))
        logging.info(f"\nLoading the softprompt \"{name}\".")
        for field, value in j.items():
            if field == 'name':
                continue
            if type(value) is list:
                # map(str, ...) so non-string list items don't make join() raise
                logging.info(f"{field}: {', '.join(map(str, value))}")
            else:
                logging.info(f"{field}: {value}")

        # NOTE: the former trailing `logging.info()` raised TypeError (msg is required)
        tensor = np.load(io.BytesIO(zf.read('tensor.npy')))

    tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
    tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
    shared.soft_prompt = True
    shared.soft_prompt_tensor = tensor
    return name