diff --git a/modules/models.py b/modules/models.py index 01201eee..24fafb37 100644 --- a/modules/models.py +++ b/modules/models.py @@ -204,9 +204,9 @@ def load_model(model_name): checkpoint = Path(f'{shared.args.model_dir}/{model_name}') if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': - config = AutoConfig.from_pretrained(checkpoint) + config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=trust_remote_code) with init_empty_weights(): - model = LoaderClass.from_config(config) + model = LoaderClass.from_config(config, trust_remote_code=trust_remote_code) model.tie_weights() params['device_map'] = infer_auto_device_map(