Remove unnecessary shared.previous_model_name variable

This commit is contained in:
oobabooga 2024-07-28 18:35:11 -07:00
parent addcb52c56
commit 9dcff21da9
4 changed files with 4 additions and 6 deletions

View File

@ -13,8 +13,8 @@ global_scores = None
def get_next_logits(*args, **kwargs): def get_next_logits(*args, **kwargs):
if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']: if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
shared.model, shared.tokenizer = load_model(shared.previous_model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
needs_lock = not args[2] # use_samplers needs_lock = not args[2] # use_samplers
if needs_lock: if needs_lock:

View File

@ -370,7 +370,6 @@ def clear_torch_cache():
def unload_model(keep_model_name=False): def unload_model(keep_model_name=False):
shared.model = shared.tokenizer = None shared.model = shared.tokenizer = None
shared.previous_model_name = shared.model_name
shared.lora_names = [] shared.lora_names = []
shared.model_dirty_from_training = False shared.model_dirty_from_training = False
clear_torch_cache() clear_torch_cache()

View File

@ -13,7 +13,6 @@ from modules.logging_colors import logger
model = None model = None
tokenizer = None tokenizer = None
model_name = 'None' model_name = 'None'
previous_model_name = 'None'
is_seq2seq = False is_seq2seq = False
model_dirty_from_training = False model_dirty_from_training = False
lora_names = [] lora_names = []

View File

@ -32,8 +32,8 @@ from modules.models import clear_torch_cache, load_model
def generate_reply(*args, **kwargs): def generate_reply(*args, **kwargs):
if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']: if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
shared.model, shared.tokenizer = load_model(shared.previous_model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
shared.generation_lock.acquire() shared.generation_lock.acquire()
try: try: