diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py
index e09c1a74..069e1007 100644
--- a/modules/llamacpp_hf.py
+++ b/modules/llamacpp_hf.py
@@ -10,11 +10,15 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from modules import shared
 from modules.logging_colors import logger
 
-if torch.cuda.is_available():
-    from llama_cpp_cuda import Llama
+if torch.cuda.is_available() and not torch.version.hip:
+    try:
+        from llama_cpp_cuda import Llama
+    except:
+        from llama_cpp import Llama
 else:
     from llama_cpp import Llama
 
+
 class LlamacppHF(PreTrainedModel):
     def __init__(self, model):
         super().__init__(PretrainedConfig())
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index c6e6ec54..048f9eac 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -1,11 +1,3 @@
-'''
-Based on
-https://github.com/abetlen/llama-cpp-python
-
-Documentation:
-https://abetlen.github.io/llama-cpp-python/
-'''
-
 import re
 from functools import partial
 
@@ -15,8 +7,11 @@ from modules import shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
 
-if torch.cuda.is_available():
-    from llama_cpp_cuda import Llama, LlamaCache, LogitsProcessorList
+if torch.cuda.is_available() and not torch.version.hip:
+    try:
+        from llama_cpp_cuda import Llama, LlamaCache, LogitsProcessorList
+    except:
+        from llama_cpp import Llama, LlamaCache, LogitsProcessorList
 else:
     from llama_cpp import Llama, LlamaCache, LogitsProcessorList
 
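
Both files now share the same guarded-import pattern: prefer the CUDA wheel of llama-cpp-python when a non-ROCm GPU is available, and fall back to the stock package otherwise. The snippet below is only a minimal sketch of that pattern outside the diff, not part of the patch itself; the model path is a hypothetical placeholder, and except ImportError is used here for clarity where the patch catches everything.

import torch

# Prefer the CUDA build of llama-cpp-python only on genuine CUDA devices;
# torch.version.hip is set on ROCm builds, where llama_cpp_cuda does not apply.
if torch.cuda.is_available() and not torch.version.hip:
    try:
        from llama_cpp_cuda import Llama
    except ImportError:
        # CUDA wheel not installed: fall back to the stock package.
        from llama_cpp import Llama
else:
    from llama_cpp import Llama

# Hypothetical usage; "models/example.bin" is a placeholder path.
llm = Llama(model_path="models/example.bin", n_ctx=2048)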