2023-04-08 02:36:04 +02:00
import gc
2023-12-06 04:01:01 +01:00
import logging
2023-02-23 17:28:30 +01:00
import os
2024-04-06 03:40:02 +02:00
import pprint
2023-03-19 23:21:41 +01:00
import re
2023-02-23 17:28:30 +01:00
import time
2023-09-25 05:23:05 +02:00
import traceback
2023-02-23 17:28:30 +01:00
from pathlib import Path
import torch
2023-02-23 17:42:23 +01:00
import transformers
2023-10-27 05:26:25 +02:00
from accelerate import infer_auto_device_map , init_empty_weights
2024-04-11 23:42:20 +02:00
from accelerate . utils import (
is_ccl_available ,
is_npu_available ,
is_xpu_available
)
2023-06-25 06:44:36 +02:00
from transformers import (
AutoConfig ,
AutoModel ,
AutoModelForCausalLM ,
AutoModelForSeq2SeqLM ,
AutoTokenizer ,
2023-09-25 05:03:11 +02:00
BitsAndBytesConfig ,
GPTQConfig
2023-06-25 06:44:36 +02:00
)
2023-02-23 18:41:42 +01:00
import modules . shared as shared
2023-12-31 05:36:51 +01:00
from modules import RoPE , sampler_hijack
2023-05-22 03:42:34 +02:00
from modules . logging_colors import logger
2023-09-11 23:49:30 +02:00
from modules . models_settings import get_model_metadata
2023-12-06 04:01:01 +01:00
from modules . relative_imports import RelativeImport
2023-02-23 17:28:30 +01:00
2023-02-23 17:42:23 +01:00
# Only show errors from Transformers' logger.
transformers.logging.set_verbosity_error()

# Rank of this process for DeepSpeed; remains None when --deepspeed is not used.
local_rank = None

if shared.args.deepspeed:
    import deepspeed
    from transformers.deepspeed import (
        HfDeepSpeedConfig,
        is_deepspeed_zero3_enabled
    )

    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup: an explicit local rank argument wins, otherwise fall back to
    # the LOCAL_RANK / WORLD_SIZE environment variables (single process by default).
    local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if is_xpu_available() and is_ccl_available():
        torch.xpu.set_device(local_rank)
        # deepspeed.init_distributed() names this parameter `dist_backend`
        # (as in the NPU branch below); there is no `backend` keyword.
        deepspeed.init_distributed(dist_backend="ccl")
    elif is_npu_available():
        torch.npu.set_device(local_rank)
        deepspeed.init_distributed(dist_backend="hccl")
    else:
        torch.cuda.set_device(local_rank)
        deepspeed.init_distributed()

    ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration

# Install the project's sampler patches (implemented in modules.sampler_hijack).
sampler_hijack.hijack_samplers()
2023-03-13 18:00:38 +01:00
2023-06-17 00:00:37 +02:00
def load_model(model_name, loader=None):
    """Load ``model_name`` with the requested or auto-detected loader.

    The loader is resolved in this order: the ``loader`` argument, then
    ``shared.args.loader``, then the ``'loader'`` entry of the model metadata.
    The chosen loader is stored back into ``shared.args.loader``. On success the
    model metadata is merged into ``shared.settings`` (existing keys only) and
    the truncation length is taken from ``max_seq_len`` (ExLlama loaders) or
    ``n_ctx`` (llama.cpp loaders).

    Args:
        model_name: Model folder/file name under ``shared.args.model_dir``.
        loader: Optional loader name; must be a key of the dispatch table below.

    Returns:
        ``(model, tokenizer)``, or ``(None, None)`` if the loader reported failure.

    Raises:
        ValueError: If no loader could be determined from the metadata.
    """
    logger.info(f"Loading \"{model_name}\"")
    t0 = time.time()

    shared.is_seq2seq = False
    shared.model_name = model_name
    load_func_map = {
        'Transformers': huggingface_loader,
        'AutoGPTQ': AutoGPTQ_loader,
        'GPTQ-for-LLaMa': GPTQ_loader,
        'llama.cpp': llamacpp_loader,
        'llamacpp_HF': llamacpp_HF_loader,
        'ExLlamav2': ExLlamav2_loader,
        'ExLlamav2_HF': ExLlamav2_HF_loader,
        'AutoAWQ': AutoAWQ_loader,
        'QuIP#': QuipSharp_loader,
        'HQQ': HQQ_loader,
    }

    metadata = get_model_metadata(model_name)
    if loader is None:
        if shared.args.loader is not None:
            loader = shared.args.loader
        else:
            loader = metadata['loader']
            if loader is None:
                logger.error('The path to the model does not exist. Exiting.')
                raise ValueError('The path to the model does not exist.')

    shared.args.loader = loader
    output = load_func_map[loader](model_name)

    # Some loaders return (model, tokenizer); the others return only the model and
    # the tokenizer is loaded separately below.
    returned_tuple = isinstance(output, tuple)
    if returned_tuple:
        model, tokenizer = output
    else:
        model, tokenizer = output, None

    # Loaders signal failure with None / (None, None). Stop here so a failed load
    # does not update shared.settings or get logged as a successful load.
    if model is None:
        return None, None

    if not returned_tuple:
        tokenizer = load_tokenizer(model_name, model)

    shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
    if loader.lower().startswith('exllama'):
        shared.settings['truncation_length'] = shared.args.max_seq_len
    elif loader in ['llama.cpp', 'llamacpp_HF']:
        shared.settings['truncation_length'] = shared.args.n_ctx

    logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
    logger.info(f"LOADER: \"{loader}\"")
    logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
    logger.info(f"INSTRUCTION TEMPLATE: \"{metadata['instruction_template']}\"")
    return model, tokenizer
def load_tokenizer(model_name, model):
    """Load a Transformers tokenizer from the model's folder.

    Args:
        model_name: Folder name under ``shared.args.model_dir``.
        model: Accepted for interface compatibility; not used.

    Returns:
        The ``AutoTokenizer`` instance, or ``None`` when the folder does not exist.
    """
    model_path = Path(f"{shared.args.model_dir}/{model_name}/")
    if not model_path.exists():
        return None

    use_fast = not shared.args.no_use_fast
    if not use_fast:
        logger.info('Loading the tokenizer with use_fast=False.')

    return AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=shared.args.trust_remote_code,
        use_fast=use_fast
    )
def huggingface_loader(model_name):
    """Load ``model_name`` from ``shared.args.model_dir`` with Transformers.

    The loading strategy depends on ``shared.args``:

    * none of the CPU / quantization / offloading / DeepSpeed / RoPE / exllama
      flags set: ``from_pretrained`` with the base params, then move the model to
      MPS, XPU, NPU or CUDA (skipped for models loaded in 4 bits);
    * ``deepspeed``: ``from_pretrained`` followed by ``deepspeed.initialize``
      with the module-level ``ds_config``;
    * otherwise: assemble ``device_map``/``max_memory``, bitsandbytes, GPTQ and
      ``rope_scaling`` params and pass them to ``from_pretrained``.

    Sets ``shared.is_seq2seq`` when the model config is encoder-decoder.

    Returns:
        The loaded model (the caller loads the tokenizer separately).
    """
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    params = {
        'low_cpu_mem_usage': True,
        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
    }

    if shared.args.trust_remote_code:
        params['trust_remote_code'] = True

    if shared.args.use_flash_attention_2:
        params['use_flash_attention_2'] = True

    if shared.args.force_safetensors:
        # `use_safetensors` is the from_pretrained() argument that restricts loading
        # to safetensors checkpoints; `force_safetensors` is not a from_pretrained()
        # argument and would be passed through to the model constructor.
        params['use_safetensors'] = True

    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)

    if 'chatglm' in model_name.lower():
        LoaderClass = AutoModel
    elif config.to_dict().get('is_encoder_decoder', False):
        LoaderClass = AutoModelForSeq2SeqLM
        shared.is_seq2seq = True
    else:
        LoaderClass = AutoModelForCausalLM

    needs_special_settings = any([
        shared.args.cpu,
        shared.args.load_in_8bit,
        shared.args.load_in_4bit,
        shared.args.auto_devices,
        shared.args.disk,
        shared.args.deepspeed,
        shared.args.gpu_memory is not None,
        shared.args.cpu_memory is not None,
        shared.args.compress_pos_emb > 1,
        shared.args.alpha_value > 1,
        shared.args.disable_exllama,
        shared.args.disable_exllamav2,
    ])

    # Load the model without any special settings
    if not needs_special_settings:
        logger.info("TRANSFORMERS_PARAMS=")
        pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
        print()

        model = LoaderClass.from_pretrained(path_to_model, **params)
        # Models loaded in 4 bits are not moved with .to()/.cuda() here.
        if not getattr(model, 'is_loaded_in_4bit', False):
            if torch.backends.mps.is_available():
                device = torch.device('mps')
                model = model.to(device)
            elif is_xpu_available():
                device = torch.device("xpu")
                model = model.to(device)
            elif is_npu_available():
                device = torch.device("npu")
                model = model.to(device)
            else:
                model = model.cuda()

    # DeepSpeed ZeRO-3
    elif shared.args.deepspeed:
        model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code'))
        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
        model.module.eval()  # Inference
        logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')

    # Load with quantization and/or offloading
    else:
        if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
            logger.warning('torch.cuda.is_available() and is_xpu_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
            shared.args.cpu = True

        if shared.args.cpu:
            params['torch_dtype'] = torch.float32
        else:
            params['device_map'] = 'auto'
            if x := get_max_memory_dict():
                params['max_memory'] = x

            if shared.args.load_in_4bit:
                # See https://github.com/huggingface/transformers/pull/23479/files
                # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
                compute_dtype = shared.args.compute_dtype
                quantization_config_params = {
                    'load_in_4bit': True,
                    # Whitelisted dtype names are resolved with getattr() rather than eval().
                    'bnb_4bit_compute_dtype': getattr(torch, compute_dtype) if compute_dtype in ["bfloat16", "float16", "float32"] else None,
                    'bnb_4bit_quant_type': shared.args.quant_type,
                    'bnb_4bit_use_double_quant': shared.args.use_double_quant,
                    'llm_int8_enable_fp32_cpu_offload': True
                }

                params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)

            elif shared.args.load_in_8bit:
                if any((shared.args.auto_devices, shared.args.gpu_memory)):
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
                else:
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)

                # With explicit memory limits, compute the device map up front from an
                # empty (meta) model so the int8 weight sizes are respected.
                if params.get('max_memory') is not None:
                    with init_empty_weights():
                        model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code'))

                    model.tie_weights()
                    params['device_map'] = infer_auto_device_map(
                        model,
                        dtype=torch.int8,
                        max_memory=params.get('max_memory'),
                        no_split_module_classes=model._no_split_modules
                    )

            if shared.args.disk:
                params['offload_folder'] = shared.args.disk_cache_dir

        if shared.args.disable_exllama or shared.args.disable_exllamav2:
            try:
                gptq_config = GPTQConfig(
                    bits=config.quantization_config.get('bits', 4),
                    disable_exllama=shared.args.disable_exllama,
                    disable_exllamav2=shared.args.disable_exllamav2,
                )

                params['quantization_config'] = gptq_config
                logger.info(f'Loading with disable_exllama={shared.args.disable_exllama} and disable_exllamav2={shared.args.disable_exllamav2}.')
            except Exception:
                # Best effort: report and continue loading without the GPTQ override.
                exc = traceback.format_exc()
                logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')
                print(exc)

        if shared.args.compress_pos_emb > 1:
            params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
        elif shared.args.alpha_value > 1:
            params['rope_scaling'] = {'type': 'dynamic', 'factor': RoPE.get_alpha_value(shared.args.alpha_value, shared.args.rope_freq_base)}

        logger.info("TRANSFORMERS_PARAMS=")
        pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
        print()
        model = LoaderClass.from_pretrained(path_to_model, **params)

    return model
2023-04-10 04:08:40 +02:00
2023-04-20 02:23:51 +02:00
2023-05-17 00:52:22 +02:00
def llamacpp_loader(model_name):
    """Load a GGUF model with llama.cpp.

    ``model_name`` may point at a single file or at a folder under
    ``shared.args.model_dir``; for a folder, the first ``*.gguf`` file in sorted
    order is used.

    Returns:
        ``(model, tokenizer)`` from ``LlamaCppModel.from_pretrained``, or
        ``(None, None)`` after logging an error when no ``.gguf`` file is found.
    """
    from modules.llamacpp_model import LlamaCppModel

    path = Path(f'{shared.args.model_dir}/{model_name}')
    if path.is_file():
        model_file = path
    else:
        gguf_files = sorted(path.glob('*.gguf'))
        if not gguf_files:
            # Previously an unexplained IndexError; report it like the other loaders do.
            logger.error(f"Could not load the model because no .gguf file was found in \"{path}\".")
            return None, None
        model_file = gguf_files[0]

    logger.info(f"llama.cpp weights detected: \"{model_file}\"")
    model, tokenizer = LlamaCppModel.from_pretrained(model_file)
    return model, tokenizer
2023-04-07 05:15:45 +02:00
2023-07-16 07:21:13 +02:00
def llamacpp_HF_loader(model_name):
    """Load a llama.cpp model through the Transformers-style ``LlamacppHF`` wrapper.

    The model folder must contain a Transformers tokenizer
    (``tokenizer_config.json``); otherwise an error is logged and
    ``(None, None)`` is returned. On success only the model is returned.
    """
    from modules.llamacpp_hf import LlamacppHF

    model_dir = Path(f'{shared.args.model_dir}/{model_name}')

    # A HF-format tokenizer must sit next to the weights.
    required_files = ['tokenizer_config.json']
    has_hf_tokenizer = all((model_dir / name).exists() for name in required_files)
    if not has_hf_tokenizer:
        logger.error("Could not load the model because a tokenizer in Transformers format was not found.")
        return None, None

    logger.info(f'Using tokenizer from: \"{model_dir}\"')
    return LlamacppHF.from_pretrained(model_name)
2023-07-16 07:21:13 +02:00
2023-10-05 18:19:18 +02:00
def AutoAWQ_loader(model_name):
    """Load an AWQ-quantized model with ``AutoAWQForCausalLM.from_quantized``.

    Layer fusion is enabled unless ``shared.args.no_inject_fused_attention`` is
    set, and safetensors loading is requested when the folder contains any
    ``*.safetensors`` file.
    """
    from awq import AutoAWQForCausalLM

    quant_dir = Path(f'{shared.args.model_dir}/{model_name}')

    from_quantized_kwargs = dict(
        quant_path=quant_dir,
        max_new_tokens=shared.args.max_seq_len,
        trust_remote_code=shared.args.trust_remote_code,
        fuse_layers=not shared.args.no_inject_fused_attention,
        max_memory=get_max_memory_dict(),
        batch_size=1,
        safetensors=any(quant_dir.glob('*.safetensors')),
    )
    return AutoAWQForCausalLM.from_quantized(**from_quantized_kwargs)
2023-10-05 18:19:18 +02:00
2023-08-11 19:41:33 +02:00
2023-12-06 04:01:01 +01:00
def QuipSharp_loader(model_name):
    """Load a QuIP# model using the manually installed ``repositories/quip-sharp`` code.

    Returns:
        The model on success, or ``(None, None)`` after logging an error when
        QuIP# cannot be imported or the tokenizer files are missing.
    """
    try:
        with RelativeImport("repositories/quip-sharp"):
            from lib.utils.unsafe_import import model_from_hf_path
    except Exception:
        # `except Exception` (not bare) so Ctrl+C / SystemExit still propagate.
        logger.error(
            "\nQuIP# has not been found. It must be installed manually for now.\n"
            "For instructions on how to do that, please consult:\n"
            "https://github.com/oobabooga/text-generation-webui/pull/4803\n"
        )
        return None, None

    # This fixes duplicate logging messages after the import above.
    handlers = logging.getLogger().handlers
    if len(handlers) > 1:
        logging.getLogger().removeHandler(handlers[1])

    model_dir = Path(f'{shared.args.model_dir}/{model_name}')
    required_files = ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']
    if not all((model_dir / file).exists() for file in required_files):
        logger.error(f"Could not load the model because the tokenizer files could not be found in the model folder. Please download the following files from the original (unquantized) model into {model_dir}: special_tokens_map.json, tokenizer.json, tokenizer.model, tokenizer_config.json.")
        return None, None

    # The second return value of model_from_hf_path is not needed here.
    model, _ = model_from_hf_path(
        model_dir,
        use_cuda_graph=False,
        use_flash_attn=not shared.args.no_flash_attn
    )

    return model
2023-05-17 00:52:22 +02:00
def GPTQ_loader(model_name):
    """Load a GPTQ model via GPTQ-for-LLaMa.

    When ``shared.args.monkey_patch`` is set, the LoRA monkey patch loader is used
    instead (with a warning); otherwise ``modules.GPTQ_loader.load_quantized``.
    Returns the model only.
    """
    if not shared.args.monkey_patch:
        # No monkey patch
        import modules.GPTQ_loader
        return modules.GPTQ_loader.load_quantized(model_name)

    # Monkey patch
    logger.warning("Applying the monkey patch for using LoRAs with GPTQ models. It may cause undefined behavior outside its intended scope.")
    from modules.monkey_patch_gptq_lora import load_model_llama

    patched_model, _ = load_model_llama(model_name)
    return patched_model
2023-05-17 16:12:12 +02:00
def AutoGPTQ_loader(model_name):
    """Load a GPTQ model through AutoGPTQ by delegating to ``modules.AutoGPTQ_loader``."""
    from modules.AutoGPTQ_loader import load_quantized

    return load_quantized(model_name)
2023-05-17 16:12:12 +02:00
2024-02-06 15:21:17 +01:00
def ExLlamav2_loader(model_name):
    """Load a model with the native ExLlamaV2 backend; returns ``(model, tokenizer)``."""
    from modules.exllamav2 import Exllamav2Model

    loaded = Exllamav2Model.from_pretrained(model_name)
    exl_model, exl_tokenizer = loaded
    return exl_model, exl_tokenizer
2023-09-12 19:33:07 +02:00
def ExLlamav2_HF_loader(model_name):
    """Load a model with the Transformers-style ExLlamaV2 wrapper (model only)."""
    from modules.exllamav2_hf import Exllamav2HF

    hf_model = Exllamav2HF.from_pretrained(model_name)
    return hf_model
2023-12-19 01:23:16 +01:00
def HQQ_loader(model_name):
    """Load an HQQ-quantized model and select the backend named by ``shared.args.hqq_backend``."""
    from hqq.core.quantize import HQQBackend, HQQLinear
    from hqq.engine.hf import HQQModelForCausalLM

    backend_name = shared.args.hqq_backend
    logger.info(f"Loading HQQ model with backend: \"{backend_name}\"")

    quantized_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
    hqq_model = HQQModelForCausalLM.from_quantized(quantized_path)

    # The backend is looked up by attribute name on the HQQBackend enum.
    HQQLinear.set_backend(getattr(HQQBackend, backend_name))
    return hqq_model
2023-05-16 00:38:27 +02:00
def _with_gib_suffix(amount):
    """Return ``amount`` unchanged if it already ends in 'ib' (any case), else append 'GiB'."""
    return amount if amount.lower().endswith('ib') else f'{amount}GiB'


def get_max_memory_dict():
    """Build an Accelerate-style ``max_memory`` mapping from the memory flags.

    * ``shared.args.gpu_memory`` set: one entry per listed GPU index plus ``'cpu'``.
      Bare numbers get a ``GiB`` suffix.
    * otherwise ``shared.args.auto_devices`` set: a suggested limit for device 0,
      derived from its total memory (XPU or CUDA), plus ``'cpu'``.

    The CPU limit comes from ``shared.args.cpu_memory`` (default ``'99GiB'``).

    Returns:
        The mapping, or ``None`` when neither flag is set.
    """
    max_memory = {}
    max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
    if shared.args.gpu_memory:
        for index, amount in enumerate(x.strip() for x in shared.args.gpu_memory):
            max_memory[index] = _with_gib_suffix(amount)

        max_memory['cpu'] = _with_gib_suffix(max_cpu_memory)

    # If --auto-devices is provided standalone, try to get a reasonable value
    # for the maximum memory of device :0
    elif shared.args.auto_devices:
        if is_xpu_available():
            total_mem = (torch.xpu.get_device_properties(0).total_memory / (1024 * 1024))
        else:
            total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))

        # total_mem is in MiB; leave roughly 1000 MiB (at least ~800 MiB) of headroom.
        suggestion = round((total_mem - 1000) / 1000) * 1000
        if total_mem - suggestion < 800:
            suggestion -= 1000

        suggestion = int(round(suggestion / 1000))
        logger.warning(f"Auto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
        max_memory[0] = f'{suggestion}GiB'
        max_memory['cpu'] = _with_gib_suffix(max_cpu_memory)

    return max_memory or None
2023-05-16 00:38:27 +02:00
2023-04-08 02:36:04 +02:00
def clear_torch_cache():
    """Run the garbage collector and release cached memory on the active accelerator.

    Only ``gc.collect()`` runs when ``shared.args.cpu`` is set. Otherwise the cache
    of the first available backend among XPU, NPU, MPS and CUDA is emptied,
    mirroring the device selection used when loading models.
    """
    gc.collect()
    if shared.args.cpu:
        return

    if is_xpu_available():
        torch.xpu.empty_cache()
    elif is_npu_available():
        torch.npu.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    else:
        torch.cuda.empty_cache()
2023-04-08 02:36:04 +02:00
def unload_model():
    """Forget the current model, tokenizer, model name and LoRAs, then free caches."""
    shared.model = None
    shared.tokenizer = None
    shared.model_name = 'None'  # the string 'None', not the None object
    shared.lora_names = []
    shared.model_dirty_from_training = False

    clear_torch_cache()
def reload_model():
    """Unload the current model and load the same model again.

    The model name is captured before unloading because ``unload_model()``
    resets ``shared.model_name`` to ``'None'``; reading it afterwards would
    try to load a model literally named "None".
    """
    model_name = shared.model_name
    unload_model()
    shared.model, shared.tokenizer = load_model(model_name)