diff --git a/modules/llama_cpp_python_hijack.py b/modules/llama_cpp_python_hijack.py
index d1e1a342..f3f3f560 100644
--- a/modules/llama_cpp_python_hijack.py
+++ b/modules/llama_cpp_python_hijack.py
@@ -7,30 +7,47 @@ from modules import shared
 from modules.cache_utils import process_llamacpp_cache
 
+imported_module = None
+
+
 def llama_cpp_lib():
+    global imported_module
+
     return_lib = None
 
     if shared.args.cpu:
+        if imported_module and imported_module != 'llama_cpp':
+            raise Exception(f"Cannot import 'llama_cpp' because '{imported_module}' is already imported. See issue #1575 in llama-cpp-python. Please restart the server before attempting to use a different version of llama-cpp-python.")
         try:
             return_lib = importlib.import_module('llama_cpp')
+            imported_module = 'llama_cpp'
         except:
             pass
 
     if shared.args.tensorcores and return_lib is None:
+        if imported_module and imported_module != 'llama_cpp_cuda_tensorcores':
+            raise Exception(f"Cannot import 'llama_cpp_cuda_tensorcores' because '{imported_module}' is already imported. See issue #1575 in llama-cpp-python. Please restart the server before attempting to use a different version of llama-cpp-python.")
         try:
             return_lib = importlib.import_module('llama_cpp_cuda_tensorcores')
+            imported_module = 'llama_cpp_cuda_tensorcores'
         except:
             pass
 
     if return_lib is None:
+        if imported_module and imported_module != 'llama_cpp_cuda':
+            raise Exception(f"Cannot import 'llama_cpp_cuda' because '{imported_module}' is already imported. See issue #1575 in llama-cpp-python. Please restart the server before attempting to use a different version of llama-cpp-python.")
         try:
             return_lib = importlib.import_module('llama_cpp_cuda')
+            imported_module = 'llama_cpp_cuda'
         except:
             pass
 
     if return_lib is None and not shared.args.cpu:
+        if imported_module and imported_module != 'llama_cpp':
+            raise Exception(f"Cannot import 'llama_cpp' because '{imported_module}' is already imported. See issue #1575 in llama-cpp-python. Please restart the server before attempting to use a different version of llama-cpp-python.")
         try:
             return_lib = importlib.import_module('llama_cpp')
+            imported_module = 'llama_cpp'
         except:
             pass