Add bf16 back here (the fp16 -> bf16 conversion takes a few seconds)

This commit is contained in:
oobabooga 2023-02-21 00:54:53 -03:00
parent bc856eb962
commit e52b697d5a

View File

@ -23,6 +23,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).') parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).") parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
args = parser.parse_args() args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
@ -30,7 +31,7 @@ if __name__ == '__main__':
model_name = path.name model_name = path.name
print(f"Loading {model_name}...") print(f"Loading {model_name}...")
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16) model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
tokenizer = AutoTokenizer.from_pretrained(path) tokenizer = AutoTokenizer.from_pretrained(path)
out_folder = args.output or Path(f"models/{model_name}_safetensors") out_folder = args.output or Path(f"models/{model_name}_safetensors")