Mirror of https://github.com/oobabooga/text-generation-webui.git
Synced 2024-11-29 10:59:32 +01:00
28 lines · 763 B · Python
'''
Converts a transformers model to a single .pt file, which is faster to load.

Example:

    python convert.py models/opt-1.3b

Output will be written to torch-dumps/name-of-the-model.pt
'''

import os
from pathlib import Path
from sys import argv

import torch
from transformers import AutoModelForCausalLM, T5ForConditionalGeneration


def _model_name(path):
    """Return the model name: the last component of the model directory *path*.

    ``os.path.abspath`` normalizes the path first, so trailing slashes,
    repeated separators and ``.`` components do not leak into the name
    (``models/opt-1.3b``, ``models/opt-1.3b/`` and ``models/opt-1.3b//``
    all give ``opt-1.3b``; on Windows backslash-separated paths work too).
    """
    return Path(os.path.abspath(path)).name


def main(args=None):
    """Load the model directory named on the command line and save it with
    ``torch.save`` to ``torch-dumps/<model-name>.pt``.

    Args:
        args: argument vector in ``sys.argv`` form (program name first,
            model directory second). Defaults to ``sys.argv``.

    Raises:
        SystemExit: if no model directory was given.
    """
    args = argv if args is None else args
    if len(args) < 2:
        raise SystemExit("Usage: python convert.py models/<model-name>")

    path = args[1]
    model_name = _model_name(path)

    print(f"Loading {model_name}...")
    if model_name in ['flan-t5', 't5-large']:
        # T5 models are loaded without the float16 cast applied below.
        model = T5ForConditionalGeneration.from_pretrained(path).cuda()
    else:
        model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
    print("Model loaded.")

    out_dir = Path("torch-dumps")
    # torch.save does not create missing parent directories.
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving to torch-dumps/{model_name}.pt")
    torch.save(model, out_dir / f"{model_name}.pt")


if __name__ == "__main__":
    main()