mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-30 06:00:15 +01:00
Remove T5 support
This commit is contained in:
parent
b2a2ddcb15
commit
18ae08ef91
@ -7,7 +7,7 @@ python convert-to-torch.py models/opt-1.3b
|
||||
The output will be written to torch-dumps/name-of-the-model.pt
|
||||
'''
|
||||
|
||||
from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
from sys import argv
|
||||
from pathlib import Path
|
||||
@ -16,10 +16,7 @@ path = Path(argv[1])
|
||||
model_name = path.name
|
||||
|
||||
print(f"Loading {model_name}...")
|
||||
if model_name in ['flan-t5', 't5-large']:
|
||||
model = T5ForConditionalGeneration.from_pretrained(path).cuda()
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
|
||||
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
|
||||
print("Model loaded.")
|
||||
|
||||
print(f"Saving to torch-dumps/{model_name}.pt")
|
||||
|
Loading…
Reference in New Issue
Block a user