2023-04-08 02:36:04 +02:00
import gc
2023-02-23 17:28:30 +01:00
import os
2024-04-06 03:40:02 +02:00
import pprint
2023-03-19 23:21:41 +01:00
import re
2023-02-23 17:28:30 +01:00
import time
2023-09-25 05:23:05 +02:00
import traceback
2023-02-23 17:28:30 +01:00
from pathlib import Path
import torch
2023-02-23 17:42:23 +01:00
import transformers
2023-10-27 05:26:25 +02:00
from accelerate import infer_auto_device_map , init_empty_weights
2024-04-11 23:42:20 +02:00
from accelerate . utils import (
is_ccl_available ,
is_npu_available ,
is_xpu_available
)
2023-06-25 06:44:36 +02:00
from transformers import (
AutoConfig ,
AutoModel ,
AutoModelForCausalLM ,
AutoModelForSeq2SeqLM ,
AutoTokenizer ,
2023-09-25 05:03:11 +02:00
BitsAndBytesConfig ,
2025-01-02 04:06:11 +01:00
GPTQConfig ,
is_torch_npu_available ,
is_torch_xpu_available
2023-06-25 06:44:36 +02:00
)
2023-02-23 18:41:42 +01:00
import modules . shared as shared
2023-05-22 03:42:34 +02:00
from modules . logging_colors import logger
2023-09-11 23:49:30 +02:00
from modules . models_settings import get_model_metadata
2023-02-23 17:28:30 +01:00
2023-02-23 17:42:23 +01:00
transformers.logging.set_verbosity_error()

# Local rank of this process; only set when DeepSpeed is enabled.
local_rank = None
if shared.args.deepspeed:
    import deepspeed
    from transformers.deepspeed import (
        HfDeepSpeedConfig,
        is_deepspeed_zero3_enabled
    )

    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup
    local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if is_xpu_available() and is_ccl_available():
        torch.xpu.set_device(local_rank)
        # deepspeed.init_distributed() takes the backend as `dist_backend`
        # (exactly as in the NPU branch below); `backend=` is not accepted.
        deepspeed.init_distributed(dist_backend="ccl")
    elif is_npu_available():
        torch.npu.set_device(local_rank)
        deepspeed.init_distributed(dist_backend="hccl")
    else:
        torch.cuda.set_device(local_rank)
        deepspeed.init_distributed()

    # Second argument: 1 * world_size (presumably one sample per process — confirm
    # against generate_ds_config's signature).
    ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration

# Timestamp of the most recent generation; compared against --idle-timeout by
# unload_model_if_idle(). Presumably refreshed by the generation code elsewhere.
last_generation_time = time.time()
2023-06-17 00:00:37 +02:00
def load_model(model_name, loader=None):
    """Load ``model_name`` with the selected backend and apply its settings.

    The loader is taken from ``loader`` if given, else from ``shared.args.loader``,
    else from the model's metadata. The chosen loader is written back to
    ``shared.args.loader``.

    Loader functions return either ``(model, tokenizer)`` or just a model; in the
    latter case the tokenizer is loaded with ``load_tokenizer``.

    Returns:
        ``(model, tokenizer)``, or ``(None, None)`` when the loader signals
        failure by returning ``None`` as the model.

    Raises:
        ValueError: if no loader was given and the metadata has none
            (the model path does not exist).
    """
    logger.info(f"Loading \"{model_name}\"")
    t0 = time.time()

    shared.is_seq2seq = False
    shared.model_name = model_name
    load_func_map = {
        'Transformers': huggingface_loader,
        'llama.cpp': llamacpp_loader,
        'llamacpp_HF': llamacpp_HF_loader,
        'ExLlamav2': ExLlamav2_loader,
        'ExLlamav2_HF': ExLlamav2_HF_loader,
        'AutoGPTQ': AutoGPTQ_loader,
        'HQQ': HQQ_loader,
        'TensorRT-LLM': TensorRT_LLM_loader,
    }

    metadata = get_model_metadata(model_name)
    if loader is None:
        if shared.args.loader is not None:
            loader = shared.args.loader
        else:
            loader = metadata['loader']
            if loader is None:
                logger.error('The path to the model does not exist. Exiting.')
                raise ValueError('The path to the model does not exist.')

    shared.args.loader = loader
    clear_torch_cache()
    output = load_func_map[loader](model_name)

    returned_tokenizer = isinstance(output, tuple)
    if returned_tokenizer:
        model, tokenizer = output
    else:
        model, tokenizer = output, None

    # Loaders signal failure with a None model (llamacpp_HF_loader returns
    # (None, None) when no HF tokenizer exists). Stop before updating settings
    # or logging a successful load.
    if model is None:
        return None, None

    if not returned_tokenizer:
        tokenizer = load_tokenizer(model_name)

    # Copy metadata values only for keys that already exist in the settings.
    shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
    if loader.lower().startswith(('exllama', 'tensorrt')):
        shared.settings['truncation_length'] = shared.args.max_seq_len
    elif loader in ['llama.cpp', 'llamacpp_HF']:
        shared.settings['truncation_length'] = shared.args.n_ctx

    logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
    logger.info(f"LOADER: \"{loader}\"")
    logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
    logger.info(f"INSTRUCTION TEMPLATE: \"{metadata['instruction_template']}\"")
    return model, tokenizer
2024-08-07 04:41:18 +02:00
def load_tokenizer(model_name, tokenizer_dir=None):
    """Load a Transformers tokenizer with ``AutoTokenizer.from_pretrained``.

    Args:
        model_name: folder name under ``shared.args.model_dir``; used only when
            ``tokenizer_dir`` is falsy.
        tokenizer_dir: optional explicit tokenizer directory.

    Returns:
        The tokenizer, or ``None`` if the chosen path does not exist.
    """
    tokenizer_path = Path(tokenizer_dir) if tokenizer_dir else Path(f"{shared.args.model_dir}/{model_name}/")

    if not tokenizer_path.exists():
        return None

    if shared.args.no_use_fast:
        logger.info('Loading the tokenizer with use_fast=False.')

    return AutoTokenizer.from_pretrained(
        tokenizer_path,
        trust_remote_code=shared.args.trust_remote_code,
        use_fast=not shared.args.no_use_fast
    )
def huggingface_loader(model_name):
    """Load ``model_name`` from ``shared.args.model_dir`` with Hugging Face Transformers.

    One of three strategies is chosen from ``shared.args``:

    * none of the quantization / offloading / memory-limit / DeepSpeed /
      RoPE-scaling / exllama flags: a plain ``from_pretrained`` call. The model
      is then moved to ``get_device()`` unless it reports ``is_loaded_in_4bit``;
    * ``--deepspeed``: load the model and wrap it with ``deepspeed.initialize``
      using the module-level ``ds_config``;
    * otherwise: build ``device_map``/``max_memory``, bitsandbytes, GPTQ and
      RoPE-scaling keyword arguments and pass them to ``from_pretrained``.

    Side effects: sets ``shared.is_seq2seq = True`` for encoder-decoder configs.
    In the third strategy it sets ``shared.args.cpu = True`` when no
    CUDA/XPU/MPS device is detected.

    Returns:
        The loaded model. In the DeepSpeed path this is the first element
        returned by ``deepspeed.initialize``.
    """
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    params = {
        'low_cpu_mem_usage': True,
        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
    }

    if shared.args.trust_remote_code:
        params['trust_remote_code'] = True

    if shared.args.use_flash_attention_2:
        params['use_flash_attention_2'] = True

    if shared.args.force_safetensors:
        # `use_safetensors` is the `from_pretrained` argument that restricts
        # loading to safetensors weights. `from_pretrained` has no
        # `force_safetensors` argument, so that key never enforced anything.
        params['use_safetensors'] = True

    if shared.args.use_eager_attention:
        params['attn_implementation'] = 'eager'

    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)

    if 'chatglm' in model_name.lower():
        LoaderClass = AutoModel
    elif config.to_dict().get('is_encoder_decoder', False):
        LoaderClass = AutoModelForSeq2SeqLM
        shared.is_seq2seq = True
    else:
        LoaderClass = AutoModelForCausalLM

    needs_special_loading = any([
        shared.args.cpu,
        shared.args.load_in_8bit,
        shared.args.load_in_4bit,
        shared.args.auto_devices,
        shared.args.disk,
        shared.args.deepspeed,
        shared.args.gpu_memory is not None,
        shared.args.cpu_memory is not None,
        shared.args.compress_pos_emb > 1,
        shared.args.alpha_value > 1,
        shared.args.disable_exllama,
        shared.args.disable_exllamav2,
    ])

    # Load the model without any special settings
    if not needs_special_loading:
        _log_transformers_params(params)
        model = LoaderClass.from_pretrained(path_to_model, **params)
        # Skip .to() for 4-bit models (presumably because bitsandbytes places
        # them on a device itself and does not support moving them).
        if not (hasattr(model, 'is_loaded_in_4bit') and model.is_loaded_in_4bit):
            device = get_device()
            if device:
                model = model.to(device)

    # DeepSpeed ZeRO-3
    elif shared.args.deepspeed:
        model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code'))
        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
        model.module.eval()  # Inference
        logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')

    # Load with quantization and/or offloading
    else:
        if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
            logger.warning('torch.cuda.is_available() and is_xpu_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
            shared.args.cpu = True

        if shared.args.cpu:
            params['torch_dtype'] = torch.float32
        else:
            params['device_map'] = 'auto'
            if x := get_max_memory_dict():
                params['max_memory'] = x

            if shared.args.load_in_4bit:
                # See https://github.com/huggingface/transformers/pull/23479/files
                # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
                compute_dtype = shared.args.compute_dtype
                quantization_config_params = {
                    'load_in_4bit': True,
                    # Whitelisted attribute lookup instead of eval().
                    'bnb_4bit_compute_dtype': getattr(torch, compute_dtype) if compute_dtype in ["bfloat16", "float16", "float32"] else None,
                    'bnb_4bit_quant_type': shared.args.quant_type,
                    'bnb_4bit_use_double_quant': shared.args.use_double_quant,
                    'llm_int8_enable_fp32_cpu_offload': True
                }

                params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)

            elif shared.args.load_in_8bit:
                if any((shared.args.auto_devices, shared.args.gpu_memory)):
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
                else:
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)

                if params.get('max_memory') is not None:
                    # Instantiate the architecture without weights to compute an
                    # int8-sized device map that honours the max_memory limits.
                    with init_empty_weights():
                        model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code'))

                    model.tie_weights()
                    params['device_map'] = infer_auto_device_map(
                        model,
                        dtype=torch.int8,
                        max_memory=params.get('max_memory'),
                        no_split_module_classes=model._no_split_modules
                    )

            if shared.args.disk:
                params['offload_folder'] = shared.args.disk_cache_dir

        if shared.args.disable_exllama or shared.args.disable_exllamav2:
            try:
                gptq_config = GPTQConfig(
                    bits=config.quantization_config.get('bits', 4),
                    disable_exllama=shared.args.disable_exllama,
                    disable_exllamav2=shared.args.disable_exllamav2,
                )

                params['quantization_config'] = gptq_config
                logger.info(f'Loading with disable_exllama={shared.args.disable_exllama} and disable_exllamav2={shared.args.disable_exllamav2}.')
            except Exception:
                # Best effort: e.g. config without a quantization_config attribute.
                # A bare except would also swallow KeyboardInterrupt/SystemExit.
                exc = traceback.format_exc()
                logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')
                print(exc)

        if shared.args.compress_pos_emb > 1:
            params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
        elif shared.args.alpha_value > 1:
            params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}

        _log_transformers_params(params)
        model = LoaderClass.from_pretrained(path_to_model, **params)

    return model


def _log_transformers_params(params):
    """Print the keyword arguments about to be passed to ``from_pretrained``."""
    logger.info("TRANSFORMERS_PARAMS=")
    pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
    print()
2023-04-10 04:08:40 +02:00
2023-04-20 02:23:51 +02:00
2023-05-17 00:52:22 +02:00
def llamacpp_loader(model_name):
    """Load a GGUF model with the llama.cpp backend.

    ``model_name`` is resolved under ``shared.args.model_dir``. If it is a file,
    that file is used. Otherwise it is treated as a directory and the first
    ``*.gguf`` file in sorted order is used.

    Returns:
        ``(model, tokenizer)`` as returned by ``LlamaCppModel.from_pretrained``.

    Raises:
        FileNotFoundError: if the directory contains no ``.gguf`` file.
    """
    from modules.llamacpp_model import LlamaCppModel

    path = Path(f'{shared.args.model_dir}/{model_name}')
    if path.is_file():
        model_file = path
    else:
        gguf_files = sorted(path.glob('*.gguf'))
        if not gguf_files:
            # Previously an opaque IndexError from indexing an empty list.
            raise FileNotFoundError(f'No .gguf file was found in "{path}".')

        model_file = gguf_files[0]

    logger.info(f"llama.cpp weights detected: \"{model_file}\"")
    model, tokenizer = LlamaCppModel.from_pretrained(model_file)
    return model, tokenizer
2023-04-07 05:15:45 +02:00
2023-07-16 07:21:13 +02:00
def llamacpp_HF_loader(model_name):
    """Load a GGUF model through the ``LlamacppHF`` wrapper.

    A tokenizer in Transformers format is required. The tokenizer source is
    chosen as follows:

    * If ``shared.args.tokenizer_dir`` is set, the tokenizer is loaded from it
      and ``(model, tokenizer)`` is returned.
    * Otherwise the model folder must contain ``tokenizer_config.json``. Only
      the model is returned, and ``load_model`` then loads the tokenizer.

    Returns:
        ``(model, tokenizer)``, ``model``, or ``(None, None)`` when no
        Transformers-format tokenizer is found.
    """
    from modules.llamacpp_hf import LlamacppHF

    if not shared.args.tokenizer_dir:
        model_folder = Path(f'{shared.args.model_dir}/{model_name}')
        # Check if a HF tokenizer is available for the model
        if not (model_folder / 'tokenizer_config.json').exists():
            logger.error("Could not load the model because a tokenizer in Transformers format was not found.")
            return None, None

        logger.info(f'Using tokenizer from: \"{model_folder}\"')
    else:
        logger.info(f'Using tokenizer from: \"{shared.args.tokenizer_dir}\"')

    model = LlamacppHF.from_pretrained(model_name)
    if not shared.args.tokenizer_dir:
        return model

    tokenizer = load_tokenizer(model_name, tokenizer_dir=shared.args.tokenizer_dir)
    return model, tokenizer
2023-07-16 07:21:13 +02:00
2024-02-06 15:21:17 +01:00
def ExLlamav2_loader(model_name):
    """Load ``model_name`` with the ExLlamaV2 backend and return ``(model, tokenizer)``."""
    from modules.exllamav2 import Exllamav2Model

    loaded_model, loaded_tokenizer = Exllamav2Model.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer
2023-09-12 19:33:07 +02:00
def ExLlamav2_HF_loader(model_name):
    """Load ``model_name`` through the ``Exllamav2HF`` wrapper.

    Returns only the model; ``load_model`` loads the tokenizer separately.
    """
    from modules.exllamav2_hf import Exllamav2HF

    model = Exllamav2HF.from_pretrained(model_name)
    return model
2024-09-29 05:30:24 +02:00
def AutoGPTQ_loader(model_name):
    """Load a GPTQ-quantized model via ``modules.AutoGPTQ_loader.load_quantized``.

    Raises:
        ModuleNotFoundError: if the AutoGPTQ integration cannot be imported. The
            original import error is chained as ``__cause__``.
    """
    try:
        import modules.AutoGPTQ_loader
    except ModuleNotFoundError as err:
        # Chain the original error so the actually-missing module stays visible.
        raise ModuleNotFoundError("Failed to import 'autogptq'. Please install it manually following the instructions in the AutoGPTQ GitHub repository.") from err

    return modules.AutoGPTQ_loader.load_quantized(model_name)
2023-12-19 01:23:16 +01:00
def HQQ_loader(model_name):
    """Load an HQQ-quantized model from ``shared.args.model_dir``.

    After loading, the HQQ backend is set from ``shared.args.hqq_backend``.

    Raises:
        ModuleNotFoundError: if ``hqq`` cannot be imported. The original import
            error is chained as ``__cause__``.
    """
    try:
        from hqq.core.quantize import HQQBackend, HQQLinear
        from hqq.models.hf.base import AutoHQQHFModel
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError("Failed to import 'hqq'. Please install it manually following the instructions in the HQQ GitHub repository.") from err

    logger.info(f"Loading HQQ model with backend: \"{shared.args.hqq_backend}\"")

    model_dir = Path(f'{shared.args.model_dir}/{model_name}')
    model = AutoHQQHFModel.from_quantized(str(model_dir))
    # The backend name must match an HQQBackend attribute.
    HQQLinear.set_backend(getattr(HQQBackend, shared.args.hqq_backend))
    return model
2024-06-24 07:30:03 +02:00
def TensorRT_LLM_loader(model_name):
    """Load ``model_name`` with the TensorRT-LLM backend.

    Raises:
        ModuleNotFoundError: if the TensorRT-LLM integration cannot be imported.
            The original import error is chained as ``__cause__``.
    """
    try:
        from modules.tensorrt_llm import TensorRTLLMModel
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.") from err

    model = TensorRTLLMModel.from_pretrained(model_name)
    return model
2023-05-16 00:38:27 +02:00
def get_max_memory_dict():
    """Build the ``max_memory`` mapping used for automatic device placement.

    * With ``--gpu-memory``, each listed value maps to its GPU index in order,
      plus a ``'cpu'`` entry.
    * With ``--auto-devices`` alone, a value for GPU 0 is derived from the total
      memory of device 0 (XPU if available, else CUDA), plus a ``'cpu'`` entry.

    The CPU limit comes from ``--cpu-memory``, or ``'99GiB'`` when that is unset.
    Values without an "ib" suffix (case-insensitive) get ``GiB`` appended.

    Returns:
        dict mapping GPU index (int) and ``'cpu'`` to memory strings, or
        ``None`` when neither option is set.
    """
    max_memory = {}
    max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
    if shared.args.gpu_memory:
        for i, gpu_mem in enumerate(value.strip() for value in shared.args.gpu_memory):
            max_memory[i] = _with_gib_suffix(gpu_mem)

        max_memory['cpu'] = _with_gib_suffix(max_cpu_memory)

    # If --auto-devices is provided standalone, try to get a reasonable value
    # for the maximum memory of device :0
    elif shared.args.auto_devices:
        if is_xpu_available():
            total_mem = (torch.xpu.get_device_properties(0).total_memory / (1024 * 1024))
        else:
            total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))

        # total_mem is in MiB; leave roughly 1 GB of headroom, rounded to whole GB.
        suggestion = round((total_mem - 1000) / 1000) * 1000
        if total_mem - suggestion < 800:
            suggestion -= 1000

        suggestion = int(round(suggestion / 1000))
        logger.warning(f"Auto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
        max_memory[0] = f'{suggestion}GiB'
        max_memory['cpu'] = _with_gib_suffix(max_cpu_memory)

    return max_memory if len(max_memory) > 0 else None


def _with_gib_suffix(value):
    """Return ``value`` unchanged if it ends in "ib" (any case), else append "GiB"."""
    return value if value.lower().endswith('ib') else f'{value}GiB'
2023-05-16 00:38:27 +02:00
2025-01-02 04:06:11 +01:00
def get_device():
    """Pick the device that a freshly loaded model should be moved to.

    Checked in this order: CUDA, the DeepSpeed accelerator (only with
    ``--deepspeed``), Apple MPS, Intel XPU, Ascend NPU.

    Returns:
        A ``torch.device``. In the DeepSpeed case, whatever
        ``current_device_name()`` returns. ``None`` when no accelerator is
        found, meaning the model stays where it is.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    if shared.args.deepspeed:
        import deepspeed
        return deepspeed.get_accelerator().current_device_name()

    if torch.backends.mps.is_available():
        return torch.device('mps')

    if is_torch_xpu_available():
        return torch.device('xpu:0')

    if is_torch_npu_available():
        return torch.device('npu:0')

    return None
2023-04-08 02:36:04 +02:00
def clear_torch_cache():
    """Run the garbage collector and release cached accelerator memory.

    The accelerator step is skipped in ``--cpu`` mode. Only the first available
    backend is cleared, checked in the order CUDA, XPU, NPU, MPS.
    """
    gc.collect()
    if shared.args.cpu:
        return

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif is_xpu_available():
        torch.xpu.empty_cache()
    elif is_npu_available():
        torch.npu.empty_cache()
    elif torch.backends.mps.is_available():
        # The MPS cache API is torch.mps.empty_cache(). torch.backends.mps has
        # no empty_cache, so the old hasattr() guard never fired. Still guard
        # for torch builds that predate torch.mps.
        mps = getattr(torch, 'mps', None)
        if mps is not None and hasattr(mps, 'empty_cache'):
            mps.empty_cache()
2023-04-08 02:36:04 +02:00
2024-07-29 03:30:06 +02:00
def unload_model(keep_model_name=False):
    """Drop the current model and tokenizer and free cached memory.

    Clears ``shared.model``, ``shared.tokenizer`` and ``shared.lora_names``,
    resets ``shared.model_dirty_from_training``, and calls
    ``clear_torch_cache()``.

    Args:
        keep_model_name: when False (the default), ``shared.model_name`` is
            reset to the string ``'None'``.
    """
    shared.model = None
    shared.tokenizer = None
    shared.lora_names = []
    shared.model_dirty_from_training = False
    clear_torch_cache()
    if keep_model_name:
        return

    shared.model_name = 'None'
2023-04-08 02:36:04 +02:00
def reload_model():
    """Unload the current model and load the same model again.

    ``unload_model()`` resets ``shared.model_name`` to ``'None'`` unless
    ``keep_model_name=True`` is passed. The name is therefore preserved here;
    otherwise the reload would request a model literally named "None".
    """
    unload_model(keep_model_name=True)
    shared.model, shared.tokenizer = load_model(shared.model_name)
2024-05-20 04:29:39 +02:00
def unload_model_if_idle():
    """Loop forever, unloading the model after ``--idle-timeout`` minutes of inactivity.

    Checks once every 60 seconds. The check runs while holding
    ``shared.generation_lock``, presumably the same lock held during
    generation. The model name is kept, so the model can be loaded again later.
    This function never returns, so it is presumably run in a background
    thread (confirm at the call site).
    """
    global last_generation_time

    logger.info(f"Setting a timeout of {shared.args.idle_timeout} minutes to unload the model in case of inactivity.")
    while True:
        shared.generation_lock.acquire()
        try:
            idle_seconds = time.time() - last_generation_time
            if idle_seconds > shared.args.idle_timeout * 60 and shared.model is not None:
                logger.info("Unloading the model for inactivity.")
                unload_model(keep_model_name=True)
        finally:
            shared.generation_lock.release()

        time.sleep(60)