diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 83f223fd7..3579a311a 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -68,7 +68,8 @@ else() target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread) elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") - target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl) + add_compile_definitions(GGML_SYCL_NVIDIA) + target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas) elseif (GGML_SYCL_TARGET STREQUAL "AMD") if (NOT GGML_SYCL_DEVICE_ARCH) message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index c2f28bb49..d1b5dd87c 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -1689,9 +1689,14 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemm( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - data_b, ldb, beta_value, data_c, ldc); +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector{ q }, + a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, + beta_value, data_c, ldc); +#else + oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, + beta_value, data_c, ldc); +#endif } template @@ -1754,14 +1759,22 @@ namespace dpct matrix_info->ld_info[2] = ldc; matrix_info->groupsize_info = batch_size; +#ifdef GGML_SYCL_NVIDIA sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - q, matrix_info->transpose_info, matrix_info->transpose_info + 1, - matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, - reinterpret_cast(a), matrix_info->ld_info, - reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), + oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, + matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, + matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast(a), + matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, + matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, + &(matrix_info->groupsize_info)); +#else + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, + matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, + reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), + matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); +#endif q.submit([&](sycl::handler &cgh) { @@ -1783,10 +1796,16 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); +#ifdef GGML_SYCL_NVIDIA oneapi::mkl::blas::column_major::gemm_batch( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - stride_a, data_b, ldb, stride_b, beta_value, - data_c, ldc, stride_c, batch_size); + oneapi::mkl::backend_selector{ q }, a_trans, b_trans, m, n, k, + alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c, + batch_size); +#else + oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, + stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, + stride_c, batch_size); +#endif } } // namespace detail diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 1310981e5..135efb521 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2573,12 +2573,17 @@ inline void ggml_sycl_op_mul_mat_sycl( const float alpha = 1.0f; const float beta = 0.0f; #if !GGML_SYCL_DNNL +# ifdef GGML_SYCL_NVIDIA SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, - dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, - src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), + oneapi::mkl::backend_selector{ *stream }, oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, + ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); +# else + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, + dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); +# endif #else auto dnnl_stream = ctx.stream_dnnl(stream); DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt(), diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index e61cdc2ca..ef9af0b76 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -40,14 +40,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr try { // Perform matrix multiplication using oneMKL GEMM - oneapi::mkl::blas::column_major::gemm(*stream, - oneapi::mkl::transpose::nontrans, src1_op, - ne0, ne1, ne01, - alpha, - src0_d, ne00, - src1_d, ldb, - beta, - dst_d, ne0); +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector{ *stream }, + oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, + ne00, src1_d, ldb, beta, dst_d, ne0); +#else + oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, + src0_d, ne00, src1_d, ldb, beta, dst_d, ne0); +#endif } catch (sycl::exception const& exc) { std::cerr << exc.what() << std::endl;