2023-04-08 03:36:04 +03:00
import gc
2023-12-06 00:01:01 -03:00
import logging
2023-02-23 13:28:30 -03:00
import os
2023-03-19 19:21:41 -03:00
import re
2023-02-23 13:28:30 -03:00
import time
2023-09-24 20:23:05 -07:00
import traceback
2023-02-23 13:28:30 -03:00
from pathlib import Path
import torch
2023-02-23 13:42:23 -03:00
import transformers
2023-10-26 20:26:25 -07:00
from accelerate import infer_auto_device_map , init_empty_weights
from accelerate . utils import is_ccl_available , is_xpu_available
2023-06-25 01:44:36 -03:00
from transformers import (
AutoConfig ,
AutoModel ,
AutoModelForCausalLM ,
AutoModelForSeq2SeqLM ,
AutoTokenizer ,
2023-09-24 20:03:11 -07:00
BitsAndBytesConfig ,
GPTQConfig
2023-06-25 01:44:36 -03:00
)
2023-02-23 14:41:42 -03:00
import modules . shared as shared
2023-09-11 18:49:30 -03:00
from modules import RoPE , llama_attn_hijack , sampler_hijack
2023-05-21 22:42:34 -03:00
from modules . logging_colors import logger
2023-09-11 18:49:30 -03:00
from modules . models_settings import get_model_metadata
2023-12-06 00:01:01 -03:00
from modules . relative_imports import RelativeImport
2023-02-23 13:28:30 -03:00
2023-02-23 13:42:23 -03:00
# Only surface errors from the transformers library.
transformers.logging.set_verbosity_error()

# Stays None unless DeepSpeed is enabled below.
local_rank = None

if shared.args.deepspeed:
    import deepspeed
    from transformers.deepspeed import (
        HfDeepSpeedConfig,
        is_deepspeed_zero3_enabled
    )

    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup: an explicit local_rank argument wins; otherwise fall
    # back to the LOCAL_RANK environment variable (default "0").
    if shared.args.local_rank is None:
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
    else:
        local_rank = shared.args.local_rank
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    # Use the "ccl" backend only when both an XPU and CCL are available;
    # otherwise bind the CUDA device and use DeepSpeed's default backend.
    if not (is_xpu_available() and is_ccl_available()):
        torch.cuda.set_device(local_rank)
        deepspeed.init_distributed()
    else:
        torch.xpu.set_device(local_rank)
        deepspeed.init_distributed(backend="ccl")

    # NOTE(review): the second argument is presumably a batch size of one per
    # rank — confirm against generate_ds_config.
    ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration

sampler_hijack.hijack_samplers()
2023-03-13 20:00:38 +03:00
2023-06-16 19:00:37 -03:00
def load_model(model_name, loader=None):
    """
    Load `model_name` with the selected backend and return (model, tokenizer).

    The loader is chosen from, in order: the `loader` argument,
    `shared.args.loader`, then the 'loader' entry of the model metadata.
    On success the metadata is merged into `shared.settings`, and the
    truncation length is overridden for ExLlama / llama.cpp / ctransformers
    loaders by their own context-length argument.

    Returns (None, None) when the loader fails to produce a model, whether it
    signals that with a bare None or with a (None, None) tuple.

    Raises:
        ValueError: no loader could be determined from the metadata.
        KeyError: the loader name is not one of the known loaders.
    """
    logger.info(f"Loading {model_name}...")
    t0 = time.time()

    shared.is_seq2seq = False
    shared.model_name = model_name
    load_func_map = {
        'Transformers': huggingface_loader,
        'AutoGPTQ': AutoGPTQ_loader,
        'GPTQ-for-LLaMa': GPTQ_loader,
        'llama.cpp': llamacpp_loader,
        'llamacpp_HF': llamacpp_HF_loader,
        'RWKV': RWKV_loader,
        'ExLlama': ExLlama_loader,
        'ExLlama_HF': ExLlama_HF_loader,
        'ExLlamav2': ExLlamav2_loader,
        'ExLlamav2_HF': ExLlamav2_HF_loader,
        'ctransformers': ctransformers_loader,
        'AutoAWQ': AutoAWQ_loader,
        'QuIP#': QuipSharp_loader,
    }

    metadata = get_model_metadata(model_name)
    if loader is None:
        if shared.args.loader is not None:
            loader = shared.args.loader
        else:
            loader = metadata['loader']
            if loader is None:
                logger.error('The path to the model does not exist. Exiting.')
                raise ValueError(f"Could not determine a loader for {model_name}.")

    load_func = load_func_map.get(loader)
    if load_func is None:
        logger.error(f"Unknown loader: {loader}. Available loaders: {', '.join(load_func_map)}.")
        raise KeyError(loader)

    shared.args.loader = loader
    output = load_func(model_name)

    # Some loaders return (model, tokenizer); others return only the model and
    # rely on load_tokenizer() below.
    returned_pair = isinstance(output, tuple)
    if returned_pair:
        model, tokenizer = output
    else:
        model, tokenizer = output, None

    # A failed load must not update shared settings or report success.
    if model is None:
        return None, None

    if not returned_pair:
        tokenizer = load_tokenizer(model_name, model)

    # Hijack attention with xformers
    if any((shared.args.xformers, shared.args.sdp_attention)):
        llama_attn_hijack.hijack_llama_attention()

    shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
    if loader.lower().startswith('exllama'):
        shared.settings['truncation_length'] = shared.args.max_seq_len
    elif loader in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
        shared.settings['truncation_length'] = shared.args.n_ctx

    logger.info(f"LOADER: {loader}")
    logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
    logger.info(f"INSTRUCTION TEMPLATE: {metadata['instruction_template']}")
    logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer
def load_tokenizer(model_name, model):
    """
    Return a Transformers tokenizer for `model_name`, or None if none is found.

    GPT-4chan models borrow the tokenizer from a `gpt-j-6B` folder when one
    exists under `shared.args.model_dir`. Otherwise, the tokenizer is read from
    the model's own folder if it exists. `model` is unused here.
    """
    models_root = shared.args.model_dir
    own_dir = Path(f"{models_root}/{model_name}/")

    name = model_name.lower()
    if any(tag in name for tag in ['gpt-4chan', 'gpt4chan']) and Path(f"{models_root}/gpt-j-6B/").exists():
        return AutoTokenizer.from_pretrained(Path(f"{models_root}/gpt-j-6B/"))

    if not own_dir.exists():
        return None

    slow = shared.args.no_use_fast
    if slow:
        logger.info('Loading the tokenizer with use_fast=False.')

    return AutoTokenizer.from_pretrained(
        own_dir,
        trust_remote_code=shared.args.trust_remote_code,
        use_fast=not slow
    )
def huggingface_loader(model_name):
    """
    Load `model_dir/model_name` with the Transformers library and return the model.

    The strategy depends on `shared.args`:

    * Plain 16-bit (bfloat16 if `bf16`, else float16), moved to MPS, XPU or
      CUDA, when none of the quantization / offloading / DeepSpeed /
      RoPE-scaling / exllama-toggle options are set.
    * DeepSpeed ZeRO-3 when `deepspeed` is set.
    * Otherwise a single `from_pretrained` call whose kwargs add bitsandbytes
      4-bit or 8-bit quantization, an automatic device map with memory
      limits, disk offloading, GPTQ exllama toggles and/or RoPE scaling.

    Side effects:
    * Sets `shared.is_seq2seq = True` for encoder-decoder configs.
    * Sets `shared.args.cpu = True` on the quantization/offloading path when no
      accelerator is detected.
    """
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    params = {
        'low_cpu_mem_usage': True,
        'trust_remote_code': shared.args.trust_remote_code,
        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
        'use_safetensors': True if shared.args.force_safetensors else None
    }

    if shared.args.use_flash_attention_2:
        params['use_flash_attention_2'] = True

    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])

    if 'chatglm' in model_name.lower():
        LoaderClass = AutoModel
    elif config.to_dict().get('is_encoder_decoder', False):
        LoaderClass = AutoModelForSeq2SeqLM
        shared.is_seq2seq = True
    else:
        LoaderClass = AutoModelForCausalLM

    # Load the model in simple 16-bit mode by default
    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama, shared.args.disable_exllamav2]):
        model = LoaderClass.from_pretrained(path_to_model, **params)
        if torch.backends.mps.is_available():
            device = torch.device('mps')
            model = model.to(device)
        elif is_xpu_available():
            device = torch.device("xpu")
            model = model.to(device)
        else:
            model = model.cuda()

    # DeepSpeed ZeRO-3
    elif shared.args.deepspeed:
        # Honor trust_remote_code here as well, matching the config load above.
        model = LoaderClass.from_pretrained(
            path_to_model,
            torch_dtype=params['torch_dtype'],
            trust_remote_code=params['trust_remote_code']
        )
        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
        model.module.eval()  # Inference
        logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')

    # Load with quantization and/or offloading
    else:
        if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
            logger.warning('torch.cuda.is_available() and is_xpu_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
            shared.args.cpu = True

        if shared.args.cpu:
            params['torch_dtype'] = torch.float32
        else:
            params['device_map'] = 'auto'
            params['max_memory'] = get_max_memory_dict()
            if shared.args.load_in_4bit:
                # See https://github.com/huggingface/transformers/pull/23479/files
                # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
                # Map the whitelisted dtype name to the torch attribute without
                # eval(); any other value passes None to BitsAndBytesConfig.
                compute_dtype = shared.args.compute_dtype
                quantization_config_params = {
                    'load_in_4bit': True,
                    'bnb_4bit_compute_dtype': getattr(torch, compute_dtype) if compute_dtype in ("bfloat16", "float16", "float32") else None,
                    'bnb_4bit_quant_type': shared.args.quant_type,
                    'bnb_4bit_use_double_quant': shared.args.use_double_quant,
                }

                logger.info('Using the following 4-bit params: ' + str(quantization_config_params))
                params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)

            elif shared.args.load_in_8bit:
                if any((shared.args.auto_devices, shared.args.gpu_memory)):
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
                else:
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)

                # With explicit memory limits, compute the int8 device map on
                # an empty (meta) model instead of relying on device_map='auto'.
                if params['max_memory'] is not None:
                    with init_empty_weights():
                        model = LoaderClass.from_config(config, trust_remote_code=params['trust_remote_code'])

                    model.tie_weights()
                    params['device_map'] = infer_auto_device_map(
                        model,
                        dtype=torch.int8,
                        max_memory=params['max_memory'],
                        no_split_module_classes=model._no_split_modules
                    )

            if shared.args.disk:
                params['offload_folder'] = shared.args.disk_cache_dir

        if shared.args.disable_exllama or shared.args.disable_exllamav2:
            try:
                gptq_config = GPTQConfig(
                    bits=config.quantization_config.get('bits', 4),
                    disable_exllama=shared.args.disable_exllama,
                    disable_exllamav2=shared.args.disable_exllamav2,
                )

                params['quantization_config'] = gptq_config
                logger.info(f'Loading with disable_exllama={shared.args.disable_exllama} and disable_exllamav2={shared.args.disable_exllamav2}.')
            except Exception:
                # Best effort: continue loading without the GPTQ override, but
                # let KeyboardInterrupt/SystemExit propagate.
                exc = traceback.format_exc()
                logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')
                print(exc)

        if shared.args.compress_pos_emb > 1:
            params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
        elif shared.args.alpha_value > 1:
            params['rope_scaling'] = {'type': 'dynamic', 'factor': RoPE.get_alpha_value(shared.args.alpha_value, shared.args.rope_freq_base)}

        model = LoaderClass.from_pretrained(path_to_model, **params)

    return model
2023-04-09 22:08:40 -04:00
2023-04-19 21:23:51 -03:00
2023-05-16 19:52:22 -03:00
def llamacpp_loader(model_name):
    """
    Load a GGUF model with llama.cpp and return (model, tokenizer).

    `model_dir/model_name` may be a single file or a folder. For a folder, the
    first `*.gguf` file yielded by glob is used. Returns (None, None) when the
    folder contains no GGUF file, matching the other tuple-returning loaders.
    """
    from modules.llamacpp_model import LlamaCppModel

    path = Path(f'{shared.args.model_dir}/{model_name}')
    if path.is_file():
        model_file = path
    else:
        model_file = next(path.glob('*.gguf'), None)
        if model_file is None:
            logger.error(f"Could not find a .gguf file for llama.cpp in {path}.")
            return None, None

    logger.info(f"llama.cpp weights detected: {model_file}")
    model, tokenizer = LlamaCppModel.from_pretrained(model_file)
    return model, tokenizer
2023-04-07 00:15:45 -03:00
2023-07-16 02:21:13 -03:00
def llamacpp_HF_loader(model_name):
    """
    Load a llama.cpp model through the LlamacppHF wrapper with a Transformers tokenizer.

    The tokenizer folder is searched, in order, among the model's own folder,
    `oobabooga_llama-tokenizer` and `llama-tokenizer` under
    `shared.args.model_dir`. A folder qualifies when it holds
    tokenizer_config.json, special_tokens_map.json and tokenizer.model.

    Returns (model, tokenizer), or (None, None) when no folder qualifies.
    """
    from modules.llamacpp_hf import LlamacppHF

    required = ('tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model')
    tokenizer_dir = None
    for folder_name in (model_name, "oobabooga_llama-tokenizer", "llama-tokenizer"):
        candidate = Path(f'{shared.args.model_dir}/{folder_name}')
        if all((candidate / filename).exists() for filename in required):
            tokenizer_dir = candidate
            break

    if tokenizer_dir is None:
        logger.error("Could not load the model because a tokenizer in transformers format was not found. Please download oobabooga/llama-tokenizer.")
        return None, None

    logger.info(f'Using tokenizer from: {tokenizer_dir}')
    if shared.args.no_use_fast:
        logger.info('Loading the tokenizer with use_fast=False.')

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir,
        trust_remote_code=shared.args.trust_remote_code,
        use_fast=not shared.args.no_use_fast
    )
    model = LlamacppHF.from_pretrained(model_name)
    return model, tokenizer
2023-08-11 17:41:33 +00:00
def ctransformers_loader(model_name):
    """
    Load a model with ctransformers and return (model, tokenizer).

    The path `model_dir/model_name` is passed through as-is in two cases: when
    `ctrans.model_type_is_auto()` is true, or when the path is a single file.
    Otherwise the folder is searched for a `*.gguf` file, then a `*.bin` file,
    and the first match is used.

    Returns (None, None) if neither kind of file is found.
    """
    from modules.ctransformers_model import CtransformersModel

    path = Path(f'{shared.args.model_dir}/{model_name}')
    ctrans = CtransformersModel()
    if ctrans.model_type_is_auto() or path.is_file():
        model_file = path
    else:
        # GGUF files take precedence over .bin files.
        model_file = next(path.glob('*.gguf'), None) or next(path.glob('*.bin'), None)
        if model_file is None:
            logger.error("Could not find a model for ctransformers.")
            return None, None

    logger.info(f'ctransformers weights detected: {model_file}')
    model, tokenizer = ctrans.from_pretrained(model_file)
    return model, tokenizer
2023-10-10 19:03:09 -07:00
2023-10-05 16:19:18 +00:00
def AutoAWQ_loader(model_name):
    """
    Load an AWQ-quantized model from `model_dir/model_name` with AutoAWQ.

    Layer fusion is enabled unless `shared.args.no_inject_fused_attention` is
    set. Safetensors loading is enabled when the folder contains at least one
    `*.safetensors` file. Returns the loaded model only.
    """
    from awq import AutoAWQForCausalLM

    quant_dir = Path(f'{shared.args.model_dir}/{model_name}')
    memory_limits = get_max_memory_dict()
    has_safetensors = any(quant_dir.glob('*.safetensors'))

    return AutoAWQForCausalLM.from_quantized(
        quant_path=quant_dir,
        max_new_tokens=shared.args.max_seq_len,
        trust_remote_code=shared.args.trust_remote_code,
        fuse_layers=not shared.args.no_inject_fused_attention,
        max_memory=memory_limits,
        batch_size=1,
        safetensors=has_safetensors,
    )
2023-10-05 16:19:18 +00:00
2023-08-11 17:41:33 +00:00
2023-12-06 00:01:01 -03:00
def QuipSharp_loader(model_name):
    """
    Load a QuIP# quantized model from `model_dir/model_name`.

    QuIP# is imported from `repositories/quip-sharp`, which must be installed
    manually. Returns only the model on success; the caller is expected to
    load the tokenizer. Returns (None, None) when QuIP# cannot be imported or
    the model folder lacks the required tokenizer files.
    """
    try:
        with RelativeImport("repositories/quip-sharp"):
            from lib.utils.unsafe_import import model_from_hf_path
    except Exception:
        # Any import failure means QuIP# is unusable. KeyboardInterrupt and
        # SystemExit are still allowed to propagate.
        logger.error(
            "\nQuIP# has not been found. It must be installed manually for now.\n"
            "For instructions on how to do that, please consult:\n"
            "https://github.com/oobabooga/text-generation-webui/pull/4803\n"
        )
        return None, None

    # This fixes duplicate logging messages after the import above.
    handlers = logging.getLogger().handlers
    if len(handlers) > 1:
        logging.getLogger().removeHandler(handlers[1])

    model_dir = Path(f'{shared.args.model_dir}/{model_name}')
    if not all((model_dir / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']):
        logger.error(f"Could not load the model because the tokenizer files could not be found in the model folder. Please download the following files from the original (unquantized) model into {model_dir}: special_tokens_map.json, tokenizer.json, tokenizer.model, tokenizer_config.json.")
        return None, None

    model, _ = model_from_hf_path(
        model_dir,
        use_cuda_graph=False,
        use_flash_attn=not shared.args.no_flash_attn
    )

    return model
2023-05-16 19:52:22 -03:00
def GPTQ_loader(model_name):
    """
    Load a GPTQ model through GPTQ-for-LLaMa and return the model.

    When `shared.args.monkey_patch` is set, the LoRA-compatible monkey-patched
    loader is used instead of `modules.GPTQ_loader.load_quantized`.
    """
    if not shared.args.monkey_patch:
        import modules.GPTQ_loader
        return modules.GPTQ_loader.load_quantized(model_name)

    logger.warning("Applying the monkey patch for using LoRAs with GPTQ models. It may cause undefined behavior outside its intended scope.")
    from modules.monkey_patch_gptq_lora import load_model_llama
    patched_model, _ = load_model_llama(model_name)
    return patched_model
2023-05-17 11:12:12 -03:00
def AutoGPTQ_loader(model_name):
    """Load a GPTQ model through the AutoGPTQ backend and return the model."""
    import modules.AutoGPTQ_loader as autogptq_backend
    return autogptq_backend.load_quantized(model_name)
2023-05-17 11:12:12 -03:00
2023-06-16 20:35:38 -03:00
def ExLlama_loader(model_name):
    """Load a model with ExLlama and return it as a (model, tokenizer) tuple."""
    from modules.exllama import ExllamaModel
    exllama_model, exllama_tokenizer = ExllamaModel.from_pretrained(model_name)
    return exllama_model, exllama_tokenizer
2023-06-22 02:31:42 +08:00
def ExLlama_HF_loader(model_name):
    """Load a model through the ExllamaHF wrapper; returns whatever from_pretrained returns."""
    from modules.exllama_hf import ExllamaHF
    hf_model = ExllamaHF.from_pretrained(model_name)
    return hf_model
2023-09-12 14:33:07 -03:00
def ExLlamav2_loader(model_name):
    """Load a model with ExLlamaV2 and return it as a (model, tokenizer) tuple."""
    from modules.exllamav2 import Exllamav2Model
    v2_model, v2_tokenizer = Exllamav2Model.from_pretrained(model_name)
    return v2_model, v2_tokenizer
def ExLlamav2_HF_loader(model_name):
    """Load a model through the Exllamav2HF wrapper; returns whatever from_pretrained returns."""
    from modules.exllamav2_hf import Exllamav2HF
    hf_model = Exllamav2HF.from_pretrained(model_name)
    return hf_model
2023-09-26 17:43:39 -07:00
def RWKV_loader(model_name):
    '''
    Load an RWKV model and its tokenizer; returns (model, tokenizer).

    Unmaintained: RWKV models can now be loaded through the transformers
    library instead. Runs in fp32 on CPU when `shared.args.cpu` is set;
    otherwise bf16 (if `shared.args.bf16`) or fp16 on XPU when available,
    else CUDA.
    '''
    from modules.RWKV import RWKVModel, RWKVTokenizer

    if shared.args.cpu:
        dtype, device = "fp32", "cpu"
    else:
        dtype = "bf16" if shared.args.bf16 else "fp16"
        device = "xpu" if is_xpu_available() else "cuda"

    model = RWKVModel.from_pretrained(
        Path(f'{shared.args.model_dir}/{model_name}'),
        dtype=dtype,
        device=device
    )
    tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
    return model, tokenizer
2023-05-15 19:38:27 -03:00
def _with_gib_suffix(amount):
    """Return `amount` unchanged if it ends in '...ib' (any case), else append 'GiB'."""
    return amount if amount.lower().endswith('ib') else f'{amount}GiB'


def get_max_memory_dict():
    """
    Build the `max_memory` mapping used for device-map offloading.

    * With `shared.args.gpu_memory`, GPU index i gets the i-th listed amount.
    * With only `shared.args.auto_devices`, GPU 0 gets a suggestion derived from
      its total memory, leaving some headroom.

    Amounts without an '...ib' unit suffix are treated as GiB. In both cases a
    'cpu' entry is added from `shared.args.cpu_memory` (default '99GiB').

    Returns None when neither option is set.
    """
    max_memory = {}
    max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
    if shared.args.gpu_memory:
        for i, amount in enumerate(shared.args.gpu_memory):
            max_memory[i] = _with_gib_suffix(amount.strip())

        max_memory['cpu'] = _with_gib_suffix(max_cpu_memory)

    # If --auto-devices is provided standalone, try to get a reasonable value
    # for the maximum memory of device :0
    elif shared.args.auto_devices:
        if is_xpu_available():
            props = torch.xpu.get_device_properties(0)
        else:
            props = torch.cuda.get_device_properties(0)

        total_mem = props.total_memory / (1024 * 1024)  # bytes -> MiB

        # Round down to whole GiB-ish units, keeping at least ~800 MiB free.
        suggestion = round((total_mem - 1000) / 1000) * 1000
        if total_mem - suggestion < 800:
            suggestion -= 1000

        suggestion = int(round(suggestion / 1000))
        logger.warning(f"Auto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
        max_memory[0] = f'{suggestion}GiB'
        max_memory['cpu'] = _with_gib_suffix(max_cpu_memory)

    return max_memory if len(max_memory) > 0 else None
2023-05-15 19:38:27 -03:00
2023-04-08 03:36:04 +03:00
def clear_torch_cache():
    """Run Python's garbage collector, then empty the XPU or CUDA cache unless in CPU mode."""
    gc.collect()
    if shared.args.cpu:
        return

    if is_xpu_available():
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()
2023-04-08 03:36:04 +03:00
def unload_model():
    """Drop the current model and tokenizer, reset LoRA/training state, and clear caches."""
    shared.model = None
    shared.tokenizer = None
    shared.lora_names = []
    shared.model_dirty_from_training = False
    clear_torch_cache()
def reload_model():
    """Unload the current model, then load `shared.model_name` again into `shared`."""
    unload_model()
    reloaded_model, reloaded_tokenizer = load_model(shared.model_name)
    shared.model, shared.tokenizer = reloaded_model, reloaded_tokenizer