diff --git a/server.py b/server.py index 56bb499d..db83b4f3 100644 --- a/server.py +++ b/server.py @@ -50,12 +50,17 @@ def get_available_softprompts(): def get_available_loras(): return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) +def unload_model(): + shared.model = shared.tokenizer = None + clear_torch_cache() + def load_model_wrapper(selected_model): if selected_model != shared.model_name: shared.model_name = selected_model - shared.model = shared.tokenizer = None - clear_torch_cache() - shared.model, shared.tokenizer = load_model(shared.model_name) + + unload_model() + if selected_model != '': + shared.model, shared.tokenizer = load_model(shared.model_name) return selected_model