Cleanup: set shared.model_name only once

This commit is contained in:
oobabooga 2023-12-08 06:35:23 -08:00
parent 62d59a516f
commit 2a335b8aa7
4 changed files with 3 additions and 5 deletions

View File

@ -55,7 +55,6 @@ def _load_model(data):
setattr(shared.args, k, args[k]) setattr(shared.args, k, args[k])
shared.model, shared.tokenizer = load_model(model_name) shared.model, shared.tokenizer = load_model(model_name)
shared.model_name = model_name
# Update shared.settings with custom generation defaults # Update shared.settings with custom generation defaults
if settings: if settings:

View File

@ -69,9 +69,8 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
model_settings = get_model_metadata(model) model_settings = get_model_metadata(model)
shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults
update_model_parameters(model_settings) # hijacking the command-line arguments update_model_parameters(model_settings) # hijacking the command-line arguments
shared.model_name = model
unload_model() unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(model)
except: except:
cumulative_log += f"Failed to load `{model}`. Moving on.\n\n" cumulative_log += f"Failed to load `{model}`. Moving on.\n\n"
yield cumulative_log yield cumulative_log

View File

@ -58,6 +58,7 @@ def load_model(model_name, loader=None):
t0 = time.time() t0 = time.time()
shared.is_seq2seq = False shared.is_seq2seq = False
shared.model_name = model_name
load_func_map = { load_func_map = {
'Transformers': huggingface_loader, 'Transformers': huggingface_loader,
'AutoGPTQ': AutoGPTQ_loader, 'AutoGPTQ': AutoGPTQ_loader,

View File

@ -203,10 +203,9 @@ def load_model_wrapper(selected_model, loader, autoload=False):
else: else:
try: try:
yield f"Loading `{selected_model}`..." yield f"Loading `{selected_model}`..."
shared.model_name = selected_model
unload_model() unload_model()
if selected_model != '': if selected_model != '':
shared.model, shared.tokenizer = load_model(shared.model_name, loader) shared.model, shared.tokenizer = load_model(selected_model, loader)
if shared.model is not None: if shared.model is not None:
output = f"Successfully loaded `{selected_model}`." output = f"Successfully loaded `{selected_model}`."