Use argparse defaults

This commit is contained in:
oobabooga 2023-04-14 15:35:06 -03:00
parent 43e01282b3
commit 3a337cfded
2 changed files with 5 additions and 11 deletions

View File

@ -147,6 +147,7 @@ parser.add_argument('--auto-launch', action='store_true', default=False, help='O
parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
args = parser.parse_args() args = parser.parse_args()
args_defaults = parser.parse_args([])
# Deprecation warnings for parameters that have been renamed # Deprecation warnings for parameters that have been renamed
deprecated_dict = {} deprecated_dict = {}

View File

@ -188,14 +188,7 @@ def download_model_wrapper(repo_id):
def update_model_parameters(state, initial=False): def update_model_parameters(state, initial=False):
elements = ui.list_model_elements() # the names of the parameters elements = ui.list_model_elements() # the names of the parameters
gpu_memories = [] gpu_memories = []
defaults = {
'wbits': 0,
'groupsize': -1,
'cpu_memory': None,
'gpu_memory': None,
'model_type': None,
'pre_layer': 0
}
for i, element in enumerate(elements): for i, element in enumerate(elements):
if element not in state: if element not in state:
continue continue
@ -205,14 +198,14 @@ def update_model_parameters(state, initial=False):
gpu_memories.append(value) gpu_memories.append(value)
continue continue
if initial and eval(f"shared.args.{element}") != defaults[element]: if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
continue continue
# Setting null defaults # Setting null defaults
if element in ['wbits', 'groupsize', 'model_type'] and value == 'None': if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
value = defaults[element] value = vars(shared.args_defaults)[element]
elif element in ['cpu_memory'] and value == 0: elif element in ['cpu_memory'] and value == 0:
value = defaults[element] value = vars(shared.args_defaults)[element]
# Making some simple conversions # Making some simple conversions
if element in ['wbits', 'groupsize', 'pre_layer']: if element in ['wbits', 'groupsize', 'pre_layer']: