text-generation-webui/convert-to-flexgen.py

64 lines
2.0 KiB
Python
Raw Normal View History

2023-02-22 01:00:06 +01:00
'''
Converts a transformers model to a format compatible with flexgen.
'''
2023-02-23 18:41:42 +01:00
2023-02-22 01:00:06 +01:00
import argparse
import os
from pathlib import Path
import numpy as np
2023-02-22 01:00:06 +01:00
import torch
from tqdm import tqdm
2023-02-23 18:41:42 +01:00
from transformers import AutoModelForCausalLM, AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
# MODEL is a required positional: the conversion cannot do anything without it,
# and argparse then reports a clean usage error instead of the script failing
# later on Path(None).
parser.add_argument('MODEL', type=str, help="Path to the input model.")
# Parsed at import time because the __main__ block below reads the global `args`.
args = parser.parse_args()
2023-02-22 01:00:06 +01:00
def disable_torch_init():
    """
    Make ``reset_parameters`` on ``torch.nn.Linear`` and ``torch.nn.LayerNorm``
    a no-op, skipping the redundant default initialization to speed up model
    creation.

    The original methods are saved in the module-level globals
    ``torch_linear_init_backup`` and ``torch_layer_norm_init_backup`` so that
    ``restore_torch_init()`` can reinstall them.
    """
    global torch_linear_init_backup, torch_layer_norm_init_backup
    import torch

    linear_cls = torch.nn.Linear
    norm_cls = torch.nn.LayerNorm

    # Stash the real initializers first...
    # NOTE(review): the backups are overwritten unconditionally, so calling
    # this twice without restoring in between saves the no-ops instead.
    torch_linear_init_backup = linear_cls.reset_parameters
    torch_layer_norm_init_backup = norm_cls.reset_parameters

    # ...then swap in methods that do nothing.
    linear_cls.reset_parameters = lambda self: None
    norm_cls.reset_parameters = lambda self: None
2023-02-22 01:00:06 +01:00
def restore_torch_init():
    """
    Undo ``disable_torch_init()`` by reinstalling the saved initializers.

    Reads the globals ``torch_linear_init_backup`` and
    ``torch_layer_norm_init_backup``, which ``disable_torch_init()`` creates;
    calling this before that function has run raises NameError.
    """
    import torch
    nn = torch.nn
    nn.Linear.reset_parameters = torch_linear_init_backup
    nn.LayerNorm.reset_parameters = torch_layer_norm_init_backup
2023-02-22 01:00:06 +01:00
if __name__ == '__main__':
    # Load the transformers checkpoint given on the command line in fp16 and
    # dump every parameter of `model.model` as a separate .npy-format file
    # under models/<model_name>-np/.
    path = Path(args.MODEL)
    model_name = path.name

    print(f"Loading {model_name}...")
    # Optional speed-up (see disable_torch_init); currently left disabled.
    # disable_torch_init()
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    # restore_torch_init()

    # parents=True also creates `models/` when it is missing (a bare os.mkdir
    # would raise FileNotFoundError); exist_ok=True lets the script be re-run.
    out_folder = Path(f"models/{model_name}-np")
    out_folder.mkdir(parents=True, exist_ok=True)

    print(f"Saving the converted model to {out_folder}...")
    # list(...) so tqdm knows the total number of parameters.
    for name, param in tqdm(list(model.model.named_parameters())):
        # Rename the final decoder norm; presumably this is the name FlexGen's
        # loader looks up — TODO confirm against FlexGen.
        name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
        param_path = out_folder / name
        # Writing through an open handle keeps the exact file name: np.save
        # only appends ".npy" when given a path string.
        with open(param_path, "wb") as f:
            np.save(f, param.detach().cpu().numpy())