mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Make FlexGen work with the newest API
This commit is contained in:
parent
48b83c9a70
commit
8e3e8a070f
@@ -16,9 +16,8 @@ transformers.logging.set_verbosity_error()
|
|||||||
local_rank = None
|
local_rank = None
|
||||||
|
|
||||||
if shared.args.flexgen:
|
if shared.args.flexgen:
|
||||||
from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
|
from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
|
||||||
TorchDevice, TorchDisk, TorchMixedDevice,
|
Policy, str2bool)
|
||||||
get_opt_config)
|
|
||||||
|
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
import deepspeed
|
import deepspeed
|
||||||
@@ -48,10 +47,8 @@ def load_model(model_name):
|
|||||||
|
|
||||||
# FlexGen
|
# FlexGen
|
||||||
elif shared.args.flexgen:
|
elif shared.args.flexgen:
|
||||||
gpu = TorchDevice("cuda:0")
|
# Initialize environment
|
||||||
cpu = TorchDevice("cpu")
|
env = ExecutionEnv.create(shared.args.disk_cache_dir)
|
||||||
disk = TorchDisk(shared.args.disk_cache_dir)
|
|
||||||
env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
|
|
||||||
|
|
||||||
# Offloading policy
|
# Offloading policy
|
||||||
policy = Policy(1, 1,
|
policy = Policy(1, 1,
|
||||||
@@ -69,9 +66,7 @@ def load_model(model_name):
|
|||||||
num_bits=4, group_size=64,
|
num_bits=4, group_size=64,
|
||||||
group_dim=2, symmetric=False))
|
group_dim=2, symmetric=False))
|
||||||
|
|
||||||
opt_config = get_opt_config(f"facebook/{shared.model_name}")
|
model = OptLM(f"facebook/{shared.model_name}", env, "models", policy)
|
||||||
model = OptLM(opt_config, env, "models", policy)
|
|
||||||
model.init_all_weights()
|
|
||||||
|
|
||||||
# DeepSpeed ZeRO-3
|
# DeepSpeed ZeRO-3
|
||||||
elif shared.args.deepspeed:
|
elif shared.args.deepspeed:
|
||||||
|
Loading…
Reference in New Issue
Block a user