diff --git a/modules/models.py b/modules/models.py index 5aaef800..26a10f7a 100644 --- a/modules/models.py +++ b/modules/models.py @@ -90,7 +90,7 @@ def load_model(model_name): from modules.RWKV import RWKVModel, RWKVTokenizer model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") - tokenizer = RWKVTokenizer.from_pretrained(Path(shared.model_name)) + tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir)) return model, tokenizer