From dcfd866402dfbbc849bd4441fd1de9448de18c75 Mon Sep 17 00:00:00 2001
From: EyeDeck
Date: Thu, 23 Mar 2023 21:31:34 -0400
Subject: [PATCH 1/3] Allow loading of .safetensors through GPTQ-for-LLaMa

---
 modules/GPTQ_loader.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 32a5458f..bec6c66f 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -37,21 +37,23 @@ def load_quantized(model_name):
     path_to_model = Path(f'models/{model_name}')
     if path_to_model.name.lower().startswith('llama-7b'):
-        pt_model = f'llama-7b-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'llama-7b-{shared.args.gptq_bits}bit'
     elif path_to_model.name.lower().startswith('llama-13b'):
-        pt_model = f'llama-13b-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'llama-13b-{shared.args.gptq_bits}bit'
     elif path_to_model.name.lower().startswith('llama-30b'):
-        pt_model = f'llama-30b-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'llama-30b-{shared.args.gptq_bits}bit'
     elif path_to_model.name.lower().startswith('llama-65b'):
-        pt_model = f'llama-65b-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'llama-65b-{shared.args.gptq_bits}bit'
     else:
-        pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'{model_name}-{shared.args.gptq_bits}bit'
 
-    # Try to find the .pt both in models/ and in the subfolder
+    # Try to find the .safetensors or .pt both in models/ and in the subfolder
     pt_path = None
-    for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+    for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
         if path.exists():
+            print(f"Found {path}")
             pt_path = path
+            break
 
     if not pt_path:
         print(f"Could not find {pt_model}, exiting...")

From 25be9698c74d7af950cbcbf8ec4c0cd9bebc6d3c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 25 Mar 2023 01:18:32 -0300
Subject: [PATCH 2/3] Fix LoRA on mps

---
 modules/LoRA.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/modules/LoRA.py b/modules/LoRA.py
index aa68ad32..283fcf4c 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -1,5 +1,7 @@
 from pathlib import Path
 
+import torch
+
 import modules.shared as shared
 from modules.models import load_model
 from modules.text_generation import clear_torch_cache
@@ -34,4 +36,8 @@ def add_lora_to_model(lora_name):
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
             if not hasattr(shared.model, "hf_device_map"):
-                shared.model.cuda()
+                if torch.has_mps:
+                    device = torch.device('mps')
+                    shared.model = shared.model.to(device)
+                else:
+                    shared.model = shared.model.cuda()

From 70f9565f37c47be34d4bdbabe3c874bc4c4c7039 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 25 Mar 2023 02:35:30 -0300
Subject: [PATCH 3/3] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 4e4959ac..60444401 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
 * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
 * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
-* [LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
+* [LLaMA model, including 4-bit GPTQ support](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
 * [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
 * [Supports LoRAs](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs).
 * Supports softprompts.