mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Make FlexGen work with the newest API
This commit is contained in:
parent
48b83c9a70
commit
8e3e8a070f
@ -16,9 +16,8 @@ transformers.logging.set_verbosity_error()
|
||||
local_rank = None
|
||||
|
||||
if shared.args.flexgen:
|
||||
from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
|
||||
TorchDevice, TorchDisk, TorchMixedDevice,
|
||||
get_opt_config)
|
||||
from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
|
||||
Policy, str2bool)
|
||||
|
||||
if shared.args.deepspeed:
|
||||
import deepspeed
|
||||
@ -48,10 +47,8 @@ def load_model(model_name):
|
||||
|
||||
# FlexGen
|
||||
elif shared.args.flexgen:
|
||||
gpu = TorchDevice("cuda:0")
|
||||
cpu = TorchDevice("cpu")
|
||||
disk = TorchDisk(shared.args.disk_cache_dir)
|
||||
env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
|
||||
# Initialize environment
|
||||
env = ExecutionEnv.create(shared.args.disk_cache_dir)
|
||||
|
||||
# Offloading policy
|
||||
policy = Policy(1, 1,
|
||||
@ -69,9 +66,7 @@ def load_model(model_name):
|
||||
num_bits=4, group_size=64,
|
||||
group_dim=2, symmetric=False))
|
||||
|
||||
opt_config = get_opt_config(f"facebook/{shared.model_name}")
|
||||
model = OptLM(opt_config, env, "models", policy)
|
||||
model.init_all_weights()
|
||||
model = OptLM(f"facebook/{shared.model_name}", env, "models", policy)
|
||||
|
||||
# DeepSpeed ZeRO-3
|
||||
elif shared.args.deepspeed:
|
||||
|
Loading…
Reference in New Issue
Block a user