From a2a3e8f797b2c6b9e7a5413886555df754cff9e6 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 1 Mar 2023 20:02:48 -0300 Subject: [PATCH] Add --rwkv-strategy parameter --- modules/RWKV.py | 5 ++++- modules/shared.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index ee8d76fb..aa4a0b91 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -25,7 +25,10 @@ class RWKVModel: def from_pretrained(self, path, dtype="fp16", device="cuda"): tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json") - model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}') + if shared.args.rwkv_strategy is None: + model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}') + else: + model = RWKV(model=os.path.abspath(path), strategy=shared.args.rwkv_strategy) pipeline = PIPELINE(model, os.path.abspath(tokenizer_path)) result = self() diff --git a/modules/shared.py b/modules/shared.py index ec1bd521..d59c1344 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -63,6 +63,7 @@ parser.add_argument("--compress-weight", action="store_true", help="FlexGen: act parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.') parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.') +parser.add_argument('--rwkv-strategy', type=str, default=None, help='The strategy to use while loading RWKV models. Examples: "cpu fp32", "cuda fp16", "cuda fp16 *30 -> cpu fp32".') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')