mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Fix deepspeed (oops)
This commit is contained in:
parent
90f1067598
commit
f38c9bf428
@ -38,7 +38,7 @@ parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to
|
||||
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
|
||||
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
|
||||
parser.add_argument('--bf16', action='store_true', help='DeepSpeed: Instantiate the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||
parser.add_argument('--local-rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
|
||||
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
|
||||
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
|
||||
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
|
||||
parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
|
||||
@ -80,7 +80,7 @@ if args.settings is not None and Path(args.settings).exists():
|
||||
if args.deepspeed:
|
||||
import deepspeed
|
||||
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
|
||||
from modules.deepseed_config import generate_ds_config
|
||||
from modules.deepspeed_parameters import generate_ds_config
|
||||
|
||||
# Distributed setup
|
||||
if args.local_rank is not None:
|
||||
@ -90,7 +90,7 @@ if args.deepspeed:
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
torch.cuda.set_device(local_rank)
|
||||
deepspeed.init_distributed()
|
||||
ds_config = generate_ds_config(args.bf16, 1 * world_size, nvme_offload_dir)
|
||||
ds_config = generate_ds_config(args.bf16, 1 * world_size, args.nvme_offload_dir)
|
||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
|
||||
def load_model(model_name):
|
||||
|
Loading…
Reference in New Issue
Block a user