mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Make --trust-remote-code work for all models (#1772)
This commit is contained in:
parent
0e6d17304a
commit
bd531c2dc2
@ -57,7 +57,7 @@ def find_model_type(model_name):
|
|||||||
elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
|
elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
|
||||||
return 'gpt4chan'
|
return 'gpt4chan'
|
||||||
else:
|
else:
|
||||||
config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'))
|
config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), trust_remote_code=shared.args.trust_remote_code)
|
||||||
# Not a "catch all", but fairly accurate
|
# Not a "catch all", but fairly accurate
|
||||||
if config.to_dict().get("is_encoder_decoder", False):
|
if config.to_dict().get("is_encoder_decoder", False):
|
||||||
return 'HF_seq2seq'
|
return 'HF_seq2seq'
|
||||||
@ -70,15 +70,13 @@ def load_model(model_name):
|
|||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
shared.model_type = find_model_type(model_name)
|
shared.model_type = find_model_type(model_name)
|
||||||
|
trust_remote_code = shared.args.trust_remote_code
|
||||||
if shared.model_type == 'chatglm':
|
if shared.model_type == 'chatglm':
|
||||||
LoaderClass = AutoModel
|
LoaderClass = AutoModel
|
||||||
trust_remote_code = shared.args.trust_remote_code
|
|
||||||
elif shared.model_type == 'HF_seq2seq':
|
elif shared.model_type == 'HF_seq2seq':
|
||||||
LoaderClass = AutoModelForSeq2SeqLM
|
LoaderClass = AutoModelForSeq2SeqLM
|
||||||
trust_remote_code = False
|
|
||||||
else:
|
else:
|
||||||
LoaderClass = AutoModelForCausalLM
|
LoaderClass = AutoModelForCausalLM
|
||||||
trust_remote_code = False
|
|
||||||
|
|
||||||
# Load the model in simple 16-bit mode by default
|
# Load the model in simple 16-bit mode by default
|
||||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
|
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
|
||||||
|
Loading…
Reference in New Issue
Block a user