From a6f476077235b18127167c55b15095a8b19b7830 Mon Sep 17 00:00:00 2001 From: 81300 <105078168+81300@users.noreply.github.com> Date: Wed, 1 Feb 2023 20:22:07 +0200 Subject: [PATCH] Add arg for bfloat16 --- server.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/server.py b/server.py index 692c84b9..ccfe033b 100644 --- a/server.py +++ b/server.py @@ -37,6 +37,7 @@ parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.') parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') parser.add_argument('--nvme-offload-dir', type=str, help='Directory to use for DeepSpeed ZeRO-3 NVME offloading.') +parser.add_argument('--bf16', action='store_true', help='Instantiate the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') parser.add_argument('--local_rank', type=int, default=0, help='Optional argument for DeepSpeed distributed setups.') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') @@ -92,14 +93,20 @@ if args.deepspeed: # DeepSpeed configration # https://huggingface.co/docs/transformers/main_classes/deepspeed + if args.bf16: + ds_fp16 = False + ds_bf16 = True + else: + ds_fp16 = True + ds_bf16 = False train_batch_size = 1 * world_size if args.nvme_offload_dir: ds_config = { "fp16": { - "enabled": True, + "enabled": ds_fp16, }, "bf16": { - "enabled": False, + "enabled": ds_bf16, }, "zero_optimization": { "stage": 3, @@ -135,10 +142,10 @@ if args.deepspeed: else: ds_config = { "fp16": { - "enabled": True, + "enabled": ds_fp16, }, "bf16": { - "enabled": False, + "enabled": ds_bf16, }, "zero_optimization": { "stage": 3, @@ -178,7 +185,10 @@ def load_model(model_name): # DeepSpeed ZeRO-3 elif args.deepspeed: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}")) + if args.bf16: + model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16) + else: + model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.float16) model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None,