Clean the convert to torch script

This commit is contained in:
oobabooga 2023-01-07 00:04:52 -03:00
parent c7b29668a2
commit 898e12058e

View File

@ -1,38 +1,27 @@
''' '''
Converts a transformers model to .pt, which is faster to load. Converts a transformers model to .pt, which is faster to load.
Run with python convert.py /path/to/model/ Example:
Make sure to write /path/to/model/ with a trailing / and not python convert.py models/opt-1.3b
/path/to/model
Output will be written to torch-dumps/name-of-the-model.pt Output will be written to torch-dumps/name-of-the-model.pt
''' '''
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, AutoTokenizer, set_seed from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
from transformers import GPT2Tokenizer, GPT2Model, T5Tokenizer, T5ForConditionalGeneration
import torch import torch
import sys
from sys import argv from sys import argv
import time
import glob
import psutil
print(f"torch-dumps/{argv[1].split('/')[-2]}.pt") path = argv[1]
if path[-1] != '/':
if argv[1].endswith('pt'): path = path+'/'
model = OPTForCausalLM.from_pretrained(argv[1], device_map="auto") model_name = path.split('/')[-2]
torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
elif 'galactica' in argv[1].lower():
model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
#model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, load_in_8bit=True)
torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
elif 'flan-t5' in argv[1].lower():
model = T5ForConditionalGeneration.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
else:
print("Loading the model")
model = AutoModelForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
print("Model loaded")
#model = AutoModelForCausalLM.from_pretrained(argv[1], device_map='auto', load_in_8bit=True)
torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
print(f"Loading {model_name}...")
if model_name in ['flan-t5', 't5-large']:
model = T5ForConditionalGeneration.from_pretrained(path).cuda()
else:
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
print("Model loaded.")
print(f"Saving to torch-dumps/{model_name}.pt")
torch.save(model, f"torch-dumps/{model_name}.pt")