'''
Converts a transformers model to .pt, which is faster to load.

Example:
python convert-to-torch.py models/opt-1.3b

The output will be written to torch-dumps/name-of-the-model.pt
'''
|
2023-02-10 15:57:55 -03:00
|
|
|
|
2023-01-07 16:33:43 -03:00
|
|
|
from pathlib import Path
|
2023-02-10 15:57:55 -03:00
|
|
|
from sys import argv
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from transformers import AutoModelForCausalLM
|
2022-12-21 13:28:19 -03:00
|
|
|
|
2023-01-07 16:33:43 -03:00
|
|
|
# Require exactly one argument: the path to the model directory.
# (Indexing argv[1] directly would raise a bare IndexError.)
if len(argv) != 2:
    raise SystemExit("Usage: python convert-to-torch.py <path-to-model>")

path = Path(argv[1])
model_name = path.name  # e.g. models/opt-1.3b -> "opt-1.3b"

print(f"Loading {model_name}...")
# low_cpu_mem_usage avoids materializing a second full copy of the weights
# in CPU RAM while loading; float16 halves the memory footprint.
# NOTE(review): .cuda() requires a CUDA-capable GPU — the script will fail
# on CPU-only machines, as in the original.
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

# torch.save does not create missing directories — ensure the output
# directory exists before writing.
out_dir = Path("torch-dumps")
out_dir.mkdir(exist_ok=True)
out_path = out_dir / f"{model_name}.pt"

print(f"Model loaded.\nSaving to {out_path}")
# Saves the whole model object (pickle), matching the original behavior.
torch.save(model, out_path)