Add a note about RWKV loader

oobabooga 2023-09-26 17:43:39 -07:00
parent 13a54729b1
commit 87ea2d96fd
2 changed files with 17 additions and 8 deletions

modules/RWKV.py

@@ -1,3 +1,8 @@
+'''
+This loader is not currently maintained as RWKV can now be loaded
+through the transformers library.
+'''
+
 import copy
 import os
 from pathlib import Path
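The note refers to the native RWKV architecture support in the transformers library. A minimal sketch of that alternative path, assuming a converted checkpoint such as RWKV/rwkv-4-169m-pile is available on the Hugging Face Hub (the model id here is an illustration, not something this commit pins down):

# Sketch: loading an RWKV checkpoint through transformers instead of
# the unmaintained custom loader. Requires transformers >= 4.29.
from transformers import AutoModelForCausalLM, AutoTokenizer

# "RWKV/rwkv-4-169m-pile" is an assumed example checkpoint.
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

inputs = tokenizer("In a shocking finding,", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))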

modules/models.py

@@ -211,14 +211,6 @@ def huggingface_loader(model_name):
     return model
 
 
-def RWKV_loader(model_name):
-    from modules.RWKV import RWKVModel, RWKVTokenizer
-
-    model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
-    tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
-    return model, tokenizer
-
-
 def llamacpp_loader(model_name):
     from modules.llamacpp_model import LlamaCppModel

@@ -335,6 +327,18 @@ def ExLlamav2_HF_loader(model_name):
     return Exllamav2HF.from_pretrained(model_name)
 
 
+def RWKV_loader(model_name):
+    '''
+    This loader is not currently maintained as RWKV can now be loaded
+    through the transformers library.
+    '''
+    from modules.RWKV import RWKVModel, RWKVTokenizer
+
+    model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+    tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
+    return model, tokenizer
+
+
 def get_max_memory_dict():
     max_memory = {}
     if shared.args.gpu_memory:
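The move does not change the loader's behavior; it only relocates the function and adds the deprecation note. For readability, the nested ternaries it uses for precision and placement can be unpacked as below (a sketch; `args` is a hypothetical stand-in for `modules.shared.args`):

# Sketch of the dtype/device selection encoded in RWKV_loader above.
def pick_dtype_and_device(args):
    # --cpu forces fp32 on the CPU; otherwise --bf16 selects bf16 on
    # CUDA, and the default is fp16 on CUDA.
    dtype = "fp32" if args.cpu else "bf16" if args.bf16 else "fp16"
    device = "cpu" if args.cpu else "cuda"
    return dtype, device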