diff --git a/server.py b/server.py index 39406339..d46ffc30 100644 --- a/server.py +++ b/server.py @@ -185,9 +185,17 @@ def download_model_wrapper(repo_id): # Update the command-line arguments based on the interface values -def update_model_parameters(state): +def update_model_parameters(state, initial=False): elements = ui.list_model_elements() # the names of the parameters gpu_memories = [] + defaults = { + 'wbits': 0, + 'groupsize': -1, + 'cpu_memory': None, + 'gpu_memory': None, + 'model_type': None, + 'pre_layer': 0 + } for i, element in enumerate(elements): if element not in state: continue @@ -197,18 +205,20 @@ def update_model_parameters(state): gpu_memories.append(value) continue - if element == 'wbits' and value == 'None': - value = 0 - if element == 'groupsize' and value == 'None': - value = -1 + if initial and eval(f"shared.args.{element}") != defaults[element]: + continue + + # Setting null defaults + if element in ['wbits', 'groupsize', 'model_type'] and value == 'None': + value = defaults[element] + elif element in ['cpu_memory'] and value == 0: + value = defaults[element] + + # Making some simple conversions if element in ['wbits', 'groupsize', 'pre_layer']: value = int(value) - if element == 'cpu_memory' and value == 0: - value = None elif element == 'cpu_memory' and value is not None: value = f"{value}MiB" - if element == 'model_type' and value == 'None': - value = None exec(f"shared.args.{element} = value") @@ -217,6 +227,7 @@ def update_model_parameters(state): if i > 0: found_positive = True break + if found_positive: shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories] else: @@ -859,7 +870,7 @@ if __name__ == "__main__": model_settings = get_model_specific_settings(shared.model_name) shared.settings.update(model_settings) # hijacking the interface defaults - update_model_parameters(model_settings) # hijacking the command-line arguments + update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments # Load the model shared.model, shared.tokenizer = load_model(shared.model_name)