diff --git a/CMakeLists.txt b/CMakeLists.txt index a2404548f..adc76d94e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,6 +75,7 @@ option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" OFF) +option(LLAMA_MPI "llama: use MPI" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) @@ -305,6 +306,23 @@ if (LLAMA_METAL) ) endif() +if (LLAMA_MPI) + cmake_minimum_required(VERSION 3.10) + find_package(MPI) + if (MPI_C_FOUND) + message(STATUS "MPI found") + set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) + add_compile_definitions(GGML_USE_MPI) + add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) + set(cxx_flags ${cxx_flags} -Wno-cast-qual) + set(c_flags ${c_flags} -Wno-cast-qual) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) + else() + message(WARNING "MPI not found") + endif() +endif() + if (LLAMA_CLBLAST) find_package(CLBlast) if (CLBlast_FOUND) @@ -473,6 +491,7 @@ add_library(ggml OBJECT ${GGML_SOURCES_CUDA} ${GGML_SOURCES_OPENCL} ${GGML_SOURCES_METAL} + ${GGML_SOURCES_MPI} ${GGML_SOURCES_EXTRA} )