diff --git a/README.md b/README.md
index 4e4959ac..60444401 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
 * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
 * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
-* [LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
+* [LLaMA model, including 4-bit GPTQ support](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
 * [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
 * [Supports LoRAs](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs).
 * Supports softprompts.
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 32a5458f..bec6c66f 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -37,21 +37,23 @@ def load_quantized(model_name):
 
     path_to_model = Path(f'models/{model_name}')
     if path_to_model.name.lower().startswith('llama-7b'):
-        pt_model = f'llama-7b-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'llama-7b-{shared.args.gptq_bits}bit'
     elif path_to_model.name.lower().startswith('llama-13b'):
-        pt_model = f'llama-13b-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'llama-13b-{shared.args.gptq_bits}bit'
     elif path_to_model.name.lower().startswith('llama-30b'):
-        pt_model = f'llama-30b-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'llama-30b-{shared.args.gptq_bits}bit'
     elif path_to_model.name.lower().startswith('llama-65b'):
-        pt_model = f'llama-65b-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'llama-65b-{shared.args.gptq_bits}bit'
     else:
-        pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
+        pt_model = f'{model_name}-{shared.args.gptq_bits}bit'
 
-    # Try to find the .pt both in models/ and in the subfolder
+    # Try to find the .safetensors or .pt both in models/ and in the subfolder
     pt_path = None
-    for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+    for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
         if path.exists():
+            print(f"Found {path}")
             pt_path = path
+            break
 
     if not pt_path:
         print(f"Could not find {pt_model}, exiting...")
diff --git a/modules/LoRA.py b/modules/LoRA.py
index 394f7367..8d608485 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -1,5 +1,7 @@
 from pathlib import Path
 
+import torch
+
 import modules.shared as shared
 from modules.models import load_model
 from modules.text_generation import clear_torch_cache
@@ -34,4 +36,8 @@ def add_lora_to_model(lora_name):
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
             if not hasattr(shared.model, "hf_device_map"):
-                shared.model.cuda()
+                if torch.has_mps:
+                    device = torch.device('mps')
+                    shared.model = shared.model.to(device)
+                else:
+                    shared.model = shared.model.cuda()
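
For reference, the new checkpoint lookup in `modules/GPTQ_loader.py` now prefers `.safetensors` over `.pt`, checks `models/` before the model's own subfolder, and stops at the first match. Below is a minimal, self-contained sketch of that search order; it ignores the LLaMA-specific name normalization, assumes the same `models/<name>-<bits>bit` naming convention, and `find_quantized_checkpoint` is a hypothetical helper, not a function in the repo:

```python
from pathlib import Path


def find_quantized_checkpoint(model_name: str, gptq_bits: int):
    """Return the first quantized checkpoint found, or None.

    Search order, mirroring the patched loop:
    .safetensors before .pt, and models/ before models/<model_name>/.
    """
    pt_model = f'{model_name}-{gptq_bits}bit'
    path_to_model = Path(f'models/{model_name}')

    for ext in ['.safetensors', '.pt']:
        for candidate in [Path(f'models/{pt_model}{ext}'), path_to_model / f'{pt_model}{ext}']:
            if candidate.exists():
                print(f"Found {candidate}")
                return candidate
    return None


# Example: checks models/llama-7b-4bit.safetensors, then
# models/llama-7b/llama-7b-4bit.safetensors, then the same two paths with .pt.
pt_path = find_quantized_checkpoint('llama-7b', 4)
```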
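The `modules/LoRA.py` change makes the post-LoRA device move work on Apple Silicon: when PyTorch was built with MPS support (`torch.has_mps`), the model is moved to the `mps` device instead of calling `.cuda()`. Here is a small sketch of the same fallback pattern; it uses the runtime check `torch.backends.mps.is_available()` (slightly stricter than `torch.has_mps`, which only reports build-time support), and `move_to_accelerator` is a hypothetical name for illustration:

```python
import torch


def move_to_accelerator(model):
    """Move a model to Apple's MPS backend if usable, otherwise to CUDA.

    Assumes at least one accelerator is present; a CPU-only setup would
    need an extra branch that returns the model unchanged.
    """
    if torch.backends.mps.is_available():
        return model.to(torch.device('mps'))
    return model.cuda()
```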