text-generation-webui/convert-to-torch.py

28 lines
763 B
Python
Raw Normal View History

2022-12-21 17:28:19 +01:00
'''
Converts a transformers model to .pt, which is faster to load.
2023-01-07 04:04:52 +01:00
Example:
python convert-to-torch.py models/opt-1.3b
2022-12-21 17:28:19 +01:00
Output will be written to torch-dumps/name-of-the-model.pt
'''
2023-01-07 04:04:52 +01:00
from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
2022-12-21 17:28:19 +01:00
import torch
from sys import argv
2023-01-07 04:04:52 +01:00
# Script body: load the transformers model found in the directory given as the
# first command-line argument, move it to the GPU, and pickle the whole model
# object to torch-dumps/<model-name>.pt.
from pathlib import Path

if len(argv) < 2:
    # Fail with a usage hint instead of a bare IndexError on argv[1].
    raise SystemExit("Usage: python convert-to-torch.py models/name-of-the-model")

model_dir = Path(argv[1])
# Path.name ignores trailing or repeated slashes and understands the platform's
# separators; fall back to the resolved path for inputs such as "." whose
# name component is empty.
model_name = model_dir.name or model_dir.resolve().name

# Create the output directory before the (slow) model load so a missing
# directory cannot make torch.save fail after all the work has been done.
output_dir = Path("torch-dumps")
output_dir.mkdir(exist_ok=True)
output_path = output_dir / f"{model_name}.pt"

print(f"Loading {model_name}...")
if model_name in ['flan-t5', 't5-large']:
    # These two model names are loaded with the T5 class rather than the
    # causal-LM auto class.
    model = T5ForConditionalGeneration.from_pretrained(str(model_dir)).cuda()
else:
    # Everything else is loaded as a causal LM in float16.
    model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    ).cuda()
print("Model loaded.")

print(f"Saving to torch-dumps/{model_name}.pt")
torch.save(model, output_path)