mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Add checks for ROCm and unsupported architectures to llama_cpp_cuda loading (#3225)
This commit is contained in:
parent
74fc5dd873
commit
1141987a0d
@ -10,11 +10,15 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and not torch.version.hip:
|
||||
try:
|
||||
from llama_cpp_cuda import Llama
|
||||
except:
|
||||
from llama_cpp import Llama
|
||||
else:
|
||||
from llama_cpp import Llama
|
||||
|
||||
|
||||
class LlamacppHF(PreTrainedModel):
|
||||
def __init__(self, model):
|
||||
super().__init__(PretrainedConfig())
|
||||
|
@ -1,11 +1,3 @@
|
||||
'''
|
||||
Based on
|
||||
https://github.com/abetlen/llama-cpp-python
|
||||
|
||||
Documentation:
|
||||
https://abetlen.github.io/llama-cpp-python/
|
||||
'''
|
||||
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
@ -15,8 +7,11 @@ from modules import shared
|
||||
from modules.callbacks import Iteratorize
|
||||
from modules.logging_colors import logger
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and not torch.version.hip:
|
||||
try:
|
||||
from llama_cpp_cuda import Llama, LlamaCache, LogitsProcessorList
|
||||
except:
|
||||
from llama_cpp import Llama, LlamaCache, LogitsProcessorList
|
||||
else:
|
||||
from llama_cpp import Llama, LlamaCache, LogitsProcessorList
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user