mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 10:59:32 +01:00
Make --trust-remote-code work for all models (#1772)
This commit is contained in:
parent
0e6d17304a
commit
bd531c2dc2
@ -57,7 +57,7 @@ def find_model_type(model_name):
|
||||
elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'))
|
||||
config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), trust_remote_code=shared.args.trust_remote_code)
|
||||
# Not a "catch all", but fairly accurate
|
||||
if config.to_dict().get("is_encoder_decoder", False):
|
||||
return 'HF_seq2seq'
|
||||
@ -70,15 +70,13 @@ def load_model(model_name):
|
||||
t0 = time.time()
|
||||
|
||||
shared.model_type = find_model_type(model_name)
|
||||
trust_remote_code = shared.args.trust_remote_code
|
||||
if shared.model_type == 'chatglm':
|
||||
LoaderClass = AutoModel
|
||||
trust_remote_code = shared.args.trust_remote_code
|
||||
elif shared.model_type == 'HF_seq2seq':
|
||||
LoaderClass = AutoModelForSeq2SeqLM
|
||||
trust_remote_code = False
|
||||
else:
|
||||
LoaderClass = AutoModelForCausalLM
|
||||
trust_remote_code = False
|
||||
|
||||
# Load the model in simple 16-bit mode by default
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
|
||||
|
Loading…
Reference in New Issue
Block a user