Make FlexGen work with the newest API

This commit is contained in:
oobabooga 2023-02-26 16:53:41 -03:00
parent 48b83c9a70
commit 8e3e8a070f

View File

@@ -16,9 +16,8 @@ transformers.logging.set_verbosity_error()
local_rank = None local_rank = None
if shared.args.flexgen: if shared.args.flexgen:
from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy, from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
TorchDevice, TorchDisk, TorchMixedDevice, Policy, str2bool)
get_opt_config)
if shared.args.deepspeed: if shared.args.deepspeed:
import deepspeed import deepspeed
@@ -48,10 +47,8 @@ def load_model(model_name):
# FlexGen # FlexGen
elif shared.args.flexgen: elif shared.args.flexgen:
gpu = TorchDevice("cuda:0") # Initialize environment
cpu = TorchDevice("cpu") env = ExecutionEnv.create(shared.args.disk_cache_dir)
disk = TorchDisk(shared.args.disk_cache_dir)
env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
# Offloading policy # Offloading policy
policy = Policy(1, 1, policy = Policy(1, 1,
@@ -69,9 +66,7 @@ def load_model(model_name):
num_bits=4, group_size=64, num_bits=4, group_size=64,
group_dim=2, symmetric=False)) group_dim=2, symmetric=False))
opt_config = get_opt_config(f"facebook/{shared.model_name}") model = OptLM(f"facebook/{shared.model_name}", env, "models", policy)
model = OptLM(opt_config, env, "models", policy)
model.init_all_weights()
# DeepSpeed ZeRO-3 # DeepSpeed ZeRO-3
elif shared.args.deepspeed: elif shared.args.deepspeed: