Don't override user initial wbits/groupsize

This commit is contained in:
oobabooga 2023-04-14 15:24:03 -03:00
parent 64e3b44e0f
commit 43e01282b3

View File

@ -185,9 +185,17 @@ def download_model_wrapper(repo_id):
# Update the command-line arguments based on the interface values
def update_model_parameters(state):
def update_model_parameters(state, initial=False):
elements = ui.list_model_elements() # the names of the parameters
gpu_memories = []
defaults = {
'wbits': 0,
'groupsize': -1,
'cpu_memory': None,
'gpu_memory': None,
'model_type': None,
'pre_layer': 0
}
for i, element in enumerate(elements):
if element not in state:
continue
@ -197,18 +205,20 @@ def update_model_parameters(state):
gpu_memories.append(value)
continue
if element == 'wbits' and value == 'None':
value = 0
if element == 'groupsize' and value == 'None':
value = -1
if initial and eval(f"shared.args.{element}") != defaults[element]:
continue
# Setting null defaults
if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
value = defaults[element]
elif element in ['cpu_memory'] and value == 0:
value = defaults[element]
# Making some simple conversions
if element in ['wbits', 'groupsize', 'pre_layer']:
value = int(value)
if element == 'cpu_memory' and value == 0:
value = None
elif element == 'cpu_memory' and value is not None:
value = f"{value}MiB"
if element == 'model_type' and value == 'None':
value = None
exec(f"shared.args.{element} = value")
@ -217,6 +227,7 @@ def update_model_parameters(state):
if i > 0:
found_positive = True
break
if found_positive:
shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories]
else:
@ -859,7 +870,7 @@ if __name__ == "__main__":
model_settings = get_model_specific_settings(shared.model_name)
shared.settings.update(model_settings) # hijacking the interface defaults
update_model_parameters(model_settings) # hijacking the command-line arguments
update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments
# Load the model
shared.model, shared.tokenizer = load_model(shared.model_name)