Fix RWKV tokenizer

This commit is contained in:
oobabooga 2023-03-27 23:42:29 -03:00
parent 036163a751
commit ee95e55df6

View File

@ -90,7 +90,7 @@ def load_model(model_name):
from modules.RWKV import RWKVModel, RWKVTokenizer from modules.RWKV import RWKVModel, RWKVTokenizer
model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
tokenizer = RWKVTokenizer.from_pretrained(Path(shared.model_name)) tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
return model, tokenizer return model, tokenizer