Make FlexGen work with the newest API

Author: oobabooga
Date:   2023-02-26 16:53:41 -03:00
Parent: 48b83c9a70
Commit: 8e3e8a070f


@@ -16,9 +16,8 @@ transformers.logging.set_verbosity_error()
 local_rank = None
 if shared.args.flexgen:
-    from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
-                                  TorchDevice, TorchDisk, TorchMixedDevice,
-                                  get_opt_config)
+    from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
+                                  Policy, str2bool)
 if shared.args.deepspeed:
     import deepspeed
@@ -48,10 +47,8 @@ def load_model(model_name):
     # FlexGen
     elif shared.args.flexgen:
-        gpu = TorchDevice("cuda:0")
-        cpu = TorchDevice("cpu")
-        disk = TorchDisk(shared.args.disk_cache_dir)
-        env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
+        # Initialize environment
+        env = ExecutionEnv.create(shared.args.disk_cache_dir)
 
         # Offloading policy
         policy = Policy(1, 1,
@@ -69,9 +66,7 @@ def load_model(model_name):
                                 num_bits=4, group_size=64,
                                 group_dim=2, symmetric=False))
 
-        opt_config = get_opt_config(f"facebook/{shared.model_name}")
-        model = OptLM(opt_config, env, "models", policy)
-        model.init_all_weights()
+        model = OptLM(f"facebook/{shared.model_name}", env, "models", policy)
 
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
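
For context, a minimal self-contained sketch of the newer FlexGen loading path this commit migrates to. This is an illustration under stated assumptions, not the repository's exact code: the model name "facebook/opt-1.3b", the "flexgen-cache" offload directory, and the Policy percentages are placeholders, while ExecutionEnv.create(), Policy, CompressionConfig, and OptLM are the flexgen.flex_opt names that appear in the diff above.

from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy

# One call now builds the GPU/CPU/disk environment that previously required
# manual TorchDevice/TorchDisk/TorchMixedDevice/Env wiring (see first hunk)
env = ExecutionEnv.create("flexgen-cache")  # placeholder offload directory

# Placeholder policy: batch size 1, weights/cache/activations kept fully on
# GPU, no compression; the repository fills these percentages from CLI flags
policy = Policy(1, 1,
                100, 0, 100, 0, 100, 0,
                overlap=True, sep_layer=True, pin_weight=True,
                cpu_cache_compute=False, attn_sparsity=1.0,
                compress_weight=False,
                comp_weight_config=CompressionConfig(num_bits=4, group_size=64,
                                                     group_dim=0, symmetric=False),
                compress_cache=False,
                comp_cache_config=CompressionConfig(num_bits=4, group_size=64,
                                                    group_dim=2, symmetric=False))

# OptLM now accepts the model name string directly and resolves the OPT
# config and weight loading itself, which is why the explicit
# get_opt_config() and init_all_weights() calls are removed above
model = OptLM("facebook/opt-1.3b", env, "models", policy)  # placeholder model

Passing the Hugging Face model name straight to OptLM is what lets the commit drop get_opt_config() and init_all_weights(): the newer FlexGen API performs both steps internally during construction.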