Better detect EXL2 models

This commit is contained in:
oobabooga 2023-09-23 13:04:27 -07:00
parent 895ec9dadb
commit 7a3ca2c68f
2 changed files with 3 additions and 1 deletions

View File

@ -76,6 +76,8 @@ def infer_loader(model_name, model_settings):
loader = 'llama.cpp' loader = 'llama.cpp'
elif re.match(r'.*rwkv.*\.pth', model_name.lower()): elif re.match(r'.*rwkv.*\.pth', model_name.lower()):
loader = 'RWKV' loader = 'RWKV'
elif re.match(r'.*exl2', model_name.lower()):
loader = 'ExLlamav2_HF'
else: else:
loader = 'Transformers' loader = 'Transformers'

View File

@ -251,7 +251,7 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
def update_truncation_length(current_length, state): def update_truncation_length(current_length, state):
if state['loader'] in ['ExLlama', 'ExLlama_HF']: if state['loader'].lower().startswith('exllama'):
return state['max_seq_len'] return state['max_seq_len']
elif state['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']: elif state['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
return state['n_ctx'] return state['n_ctx']