Fix bug in models.py

This commit is contained in:
oobabooga 2023-04-26 01:55:40 -03:00 committed by GitHub
parent 4c491aa142
commit a8409426d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -40,20 +40,20 @@ if shared.args.deepspeed:
def find_model_type(model_name): def find_model_type(model_name):
model_name = model_name.lower() model_name_lower = model_name.lower()
if 'rwkv-' in model_name: if 'rwkv-' in model_name_lower:
return 'rwkv' return 'rwkv'
elif len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))) > 0: elif len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))) > 0:
return 'llamacpp' return 'llamacpp'
elif re.match('.*ggml.*\.bin', model_name): elif re.match('.*ggml.*\.bin', model_name_lower):
return 'llamacpp' return 'llamacpp'
elif 'chatglm' in model_name: elif 'chatglm' in model_name_lower:
return 'chatglm' return 'chatglm'
elif 'galactica' in model_name: elif 'galactica' in model_name_lower:
return 'galactica' return 'galactica'
elif 'llava' in model_name: elif 'llava' in model_name_lower:
return 'llava' return 'llava'
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])): elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
return 'gpt4chan' return 'gpt4chan'
else: else:
config = AutoConfig.from_pretrained(f"{shared.args.model_dir}/{model_name}") config = AutoConfig.from_pretrained(f"{shared.args.model_dir}/{model_name}")