mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Add --rwkv-strategy parameter
This commit is contained in:
parent
99dc95e14e
commit
a2a3e8f797
@ -25,7 +25,10 @@ class RWKVModel:
|
||||
def from_pretrained(self, path, dtype="fp16", device="cuda"):
|
||||
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
|
||||
|
||||
model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}')
|
||||
if shared.args.rwkv_strategy is None:
|
||||
model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}')
|
||||
else:
|
||||
model = RWKV(model=os.path.abspath(path), strategy=shared.args.rwkv_strategy)
|
||||
pipeline = PIPELINE(model, os.path.abspath(tokenizer_path))
|
||||
|
||||
result = self()
|
||||
|
@ -63,6 +63,7 @@ parser.add_argument("--compress-weight", action="store_true", help="FlexGen: act
|
||||
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
|
||||
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
|
||||
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
|
||||
parser.add_argument('--rwkv-strategy', type=str, default=None, help='The strategy to use while loading RWKV models. Examples: "cpu fp32", "cuda fp16", "cuda fp16 *30 -> cpu fp32".')
|
||||
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
|
||||
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
|
||||
parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
|
||||
|
Loading…
Reference in New Issue
Block a user