Simplify GPTQ_loader.py

This commit is contained in:
oobabooga 2023-04-13 12:13:07 -03:00
parent c13e8651ad
commit a75e02de4d

View File

@ -111,12 +111,10 @@ def load_quantized(model_name):
pt_path = None pt_path = None
priority_name_list = [ priority_name_list = [
Path(f'{shared.args.model_dir}/{model_name}{hyphen}{shared.args.wbits}bit{group}{ext}') Path(f'{shared.args.model_dir}/{model_name}{hyphen}{shared.args.wbits}bit{group}{ext}')
for ext in ['.safetensors', '.pt']
for group in ([f'-{shared.args.groupsize}g', ''] if shared.args.groupsize > 0 else ['']) for group in ([f'-{shared.args.groupsize}g', ''] if shared.args.groupsize > 0 else [''])
for ext in ['.safetensors', '.pt']
for hyphen in ['-', f'/{model_name}-', '/'] for hyphen in ['-', f'/{model_name}-', '/']
] ]
if shared.args.groupsize > 0:
priority_name_list = [i for i in priority_name_list if str(shared.args.groupsize) in i.name] + [i for i in priority_name_list if str(shared.args.groupsize) not in i.name]
for path in priority_name_list: for path in priority_name_list:
if path.exists(): if path.exists():
pt_path = path pt_path = path