diff --git a/modules/models.py b/modules/models.py index 0cb9ae6e..1264a58c 100644 --- a/modules/models.py +++ b/modules/models.py @@ -16,9 +16,8 @@ transformers.logging.set_verbosity_error() local_rank = None if shared.args.flexgen: - from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy, - TorchDevice, TorchDisk, TorchMixedDevice, - get_opt_config) + from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM, + Policy, str2bool) if shared.args.deepspeed: import deepspeed @@ -48,10 +47,8 @@ def load_model(model_name): # FlexGen elif shared.args.flexgen: - gpu = TorchDevice("cuda:0") - cpu = TorchDevice("cpu") - disk = TorchDisk(shared.args.disk_cache_dir) - env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk])) + # Initialize environment + env = ExecutionEnv.create(shared.args.disk_cache_dir) # Offloading policy policy = Policy(1, 1, @@ -69,9 +66,7 @@ def load_model(model_name): num_bits=4, group_size=64, group_dim=2, symmetric=False)) - opt_config = get_opt_config(f"facebook/{shared.model_name}") - model = OptLM(opt_config, env, "models", policy) - model.init_all_weights() + model = OptLM(f"facebook/{shared.model_name}", env, "models", policy) # DeepSpeed ZeRO-3 elif shared.args.deepspeed: