Code reuse + indication

Now prints a message to the console when the model weights are unloaded. Also, reload_model() now calls unload_model() first to free memory, so that repeated reloads won't exhaust it.
This commit is contained in:
Φφ 2023-03-21 20:19:38 +03:00
parent 1917b15275
commit 483d173d23

View File

@ -64,9 +64,7 @@ def load_model_wrapper(selected_model):
return selected_model return selected_model
def reload_model(): def reload_model():
if not shared.args.cpu: unload_model()
gc.collect()
torch.cuda.empty_cache()
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
def unload_model(): def unload_model():
@ -74,6 +72,7 @@ def unload_model():
if not shared.args.cpu: if not shared.args.cpu:
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
print("Model weights unloaded.")
def load_lora_wrapper(selected_lora): def load_lora_wrapper(selected_lora):
shared.lora_name = selected_lora shared.lora_name = selected_lora