From 1141987a0d499a0b59a47f819df4edeb195d54d5 Mon Sep 17 00:00:00 2001
From: jllllll <3887729+jllllll@users.noreply.github.com>
Date: Mon, 24 Jul 2023 09:25:36 -0500
Subject: [PATCH] Add checks for ROCm and unsupported architectures to
 llama_cpp_cuda loading (#3225)

---
 modules/llamacpp_hf.py    |  8 ++++++--
 modules/llamacpp_model.py | 15 +++++----------
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py
index e09c1a74..069e1007 100644
--- a/modules/llamacpp_hf.py
+++ b/modules/llamacpp_hf.py
@@ -10,11 +10,15 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from modules import shared
 from modules.logging_colors import logger
 
-if torch.cuda.is_available():
-    from llama_cpp_cuda import Llama
+if torch.cuda.is_available() and not torch.version.hip:
+    try:
+        from llama_cpp_cuda import Llama
+    except:
+        from llama_cpp import Llama
 else:
     from llama_cpp import Llama
 
+
 class LlamacppHF(PreTrainedModel):
     def __init__(self, model):
         super().__init__(PretrainedConfig())
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index c6e6ec54..048f9eac 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -1,11 +1,3 @@
-'''
-Based on
-https://github.com/abetlen/llama-cpp-python
-
-Documentation:
-https://abetlen.github.io/llama-cpp-python/
-'''
-
 import re
 from functools import partial
 
@@ -15,8 +7,11 @@ from modules import shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
 
-if torch.cuda.is_available():
-    from llama_cpp_cuda import Llama, LlamaCache, LogitsProcessorList
+if torch.cuda.is_available() and not torch.version.hip:
+    try:
+        from llama_cpp_cuda import Llama, LlamaCache, LogitsProcessorList
+    except:
+        from llama_cpp import Llama, LlamaCache, LogitsProcessorList
 else:
     from llama_cpp import Llama, LlamaCache, LogitsProcessorList
 
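
Note (not part of the patch): the new gate works because torch.version.hip is None on CUDA and CPU builds of PyTorch and a version string on ROCm builds, so "and not torch.version.hip" keeps ROCm systems on the regular llama_cpp package even when torch.cuda.is_available() reports True, while the try/except covers a llama_cpp_cuda wheel that is missing or built for an unsupported GPU architecture. Below is a minimal sketch of the same selection logic consolidated into one helper; the helper name and the use of importlib are illustrative assumptions, not code from this patch.

import importlib

import torch


def _load_llama_cpp(names):
    """Return the requested symbols, preferring the CUDA wheel on NVIDIA GPUs.

    Hypothetical helper for illustration only; the patch inlines this logic
    directly in modules/llamacpp_hf.py and modules/llamacpp_model.py.
    """
    if torch.cuda.is_available() and not torch.version.hip:
        try:
            # llama_cpp_cuda may be absent or built for an unsupported GPU
            # architecture; a failed import falls through to the CPU package.
            module = importlib.import_module('llama_cpp_cuda')
            return tuple(getattr(module, name) for name in names)
        except Exception:
            pass
    module = importlib.import_module('llama_cpp')
    return tuple(getattr(module, name) for name in names)


# Roughly what modules/llamacpp_model.py ends up doing:
# Llama, LlamaCache, LogitsProcessorList = _load_llama_cpp(
#     ['Llama', 'LlamaCache', 'LogitsProcessorList'])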