Add checks for ROCm and unsupported architectures to llama_cpp_cuda loading (#3225)

Commit 1141987a0d (parent 74fc5dd873)
Author: jllllll, 2023-07-24 09:25:36 -05:00, committed via GitHub
2 changed files with 11 additions and 12 deletions


@@ -10,11 +10,15 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from modules import shared
 from modules.logging_colors import logger
-if torch.cuda.is_available():
-    from llama_cpp_cuda import Llama
+if torch.cuda.is_available() and not torch.version.hip:
+    try:
+        from llama_cpp_cuda import Llama
+    except:
+        from llama_cpp import Llama
 else:
     from llama_cpp import Llama
 class LlamacppHF(PreTrainedModel):
     def __init__(self, model):
         super().__init__(PretrainedConfig())
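
Read as one block, the new import gating introduced here is the following (comments added for explanation only; the bare except is as in the commit):

import torch

# torch.cuda.is_available() is also True on ROCm builds of PyTorch, so the
# extra torch.version.hip check keeps llama_cpp_cuda off AMD/ROCm systems.
if torch.cuda.is_available() and not torch.version.hip:
    try:
        from llama_cpp_cuda import Llama
    except:
        # Import can fail when the CUDA wheel is missing or was not built for
        # the local GPU architecture; fall back to the generic llama_cpp package.
        from llama_cpp import Llama
else:
    from llama_cpp import Llama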


@@ -1,11 +1,3 @@
-'''
-Based on
-https://github.com/abetlen/llama-cpp-python
-Documentation:
-https://abetlen.github.io/llama-cpp-python/
-'''
 import re
 from functools import partial
@@ -15,8 +7,11 @@ from modules import shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
-if torch.cuda.is_available():
-    from llama_cpp_cuda import Llama, LlamaCache, LogitsProcessorList
+if torch.cuda.is_available() and not torch.version.hip:
+    try:
+        from llama_cpp_cuda import Llama, LlamaCache, LogitsProcessorList
+    except:
+        from llama_cpp import Llama, LlamaCache, LogitsProcessorList
 else:
     from llama_cpp import Llama, LlamaCache, LogitsProcessorList
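
For context, torch.version.hip is None on CUDA and CPU builds of PyTorch and a version string on ROCm builds, while torch.cuda.is_available() also reports True for AMD GPUs under ROCm; that combination is what makes the new condition an NVIDIA-only gate. A quick interpreter check (illustrative values, not from the commit):

import torch

print(torch.cuda.is_available())  # also True on a ROCm build with an AMD GPU
print(torch.version.hip)          # None on CUDA/CPU builds; e.g. "5.6" on ROCm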