Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-11-26 01:30:20 +01:00
Clean the convert to torch script
commit 898e12058e
parent c7b29668a2
@@ -1,38 +1,27 @@
 '''
 Converts a transformers model to .pt, which is faster to load.
 
-Run with python convert.py /path/to/model/
-Make sure to write /path/to/model/ with a trailing / and not
-/path/to/model
+Example:
+python convert.py models/opt-1.3b
 
 Output will be written to torch-dumps/name-of-the-model.pt
 '''
 
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, AutoTokenizer, set_seed
-from transformers import GPT2Tokenizer, GPT2Model, T5Tokenizer, T5ForConditionalGeneration
+from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
 import torch
-import sys
 from sys import argv
-import time
-import glob
-import psutil
 
-print(f"torch-dumps/{argv[1].split('/')[-2]}.pt")
+path = argv[1]
+if path[-1] != '/':
+    path = path+'/'
+model_name = path.split('/')[-2]
 
-if argv[1].endswith('pt'):
-    model = OPTForCausalLM.from_pretrained(argv[1], device_map="auto")
-    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
-elif 'galactica' in argv[1].lower():
-    model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
-    #model = OPTForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, load_in_8bit=True)
-    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
-elif 'flan-t5' in argv[1].lower():
-    model = T5ForConditionalGeneration.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
-    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
+print(f"Loading {model_name}...")
+if model_name in ['flan-t5', 't5-large']:
+    model = T5ForConditionalGeneration.from_pretrained(path).cuda()
 else:
-    print("Loading the model")
-    model = AutoModelForCausalLM.from_pretrained(argv[1], low_cpu_mem_usage=True, torch_dtype=torch.float16)
-    print("Model loaded")
-    #model = AutoModelForCausalLM.from_pretrained(argv[1], device_map='auto', load_in_8bit=True)
-    torch.save(model, f"torch-dumps/{argv[1].split('/')[-2]}.pt")
+    model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
+print("Model loaded.")
 
+print(f"Saving to torch-dumps/{model_name}.pt")
+torch.save(model, f"torch-dumps/{model_name}.pt")
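
Note on the new path handling: the model name is taken from the second-to-last path component, which is why the script first appends a trailing slash when it is missing. A minimal illustration (hypothetical paths, following the docstring's example):

    'models/opt-1.3b/'.split('/')[-2]  # -> 'opt-1.3b' (normalized path)
    'models/opt-1.3b'.split('/')[-2]   # -> 'models', i.e. the parent dir (why the slash is appended)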
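Reading the dump back is the matching torch.load call; a minimal sketch, assuming the example invocation above (python convert.py models/opt-1.3b) was used. The dump contains only the pickled model object, so the tokenizer still has to come from the original model directory.

    import torch

    # torch.save(model, ...) pickled the whole model object, so torch.load
    # restores it directly instead of rebuilding it from the transformers
    # checkpoint, which is what makes the .pt faster to load.
    model = torch.load('torch-dumps/opt-1.3b.pt')  # pass map_location='cpu' to load without a GPU
    model.eval()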