From 948f4ec7c5bff92b18e63303f2b2d1645bccd943 Mon Sep 17 00:00:00 2001 From: Neo Zhang <14088817+arthw@users.noreply.github.com> Date: Mon, 13 May 2024 18:11:26 +0800 Subject: [PATCH] [SYCL] rm wait() (#7233) --- ggml-sycl.cpp | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index e93d2af63..724070eb9 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -15564,26 +15564,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; -#if 0 - // use syclGemmEx - { - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - int i03 = i13 / r3; - int i02 = i12 / r2; - - SYCL_CHECK( - syclGemmEx(g_sycl_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , SYCL_R_16F, nb01/sizeof(half), - (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, SYCL_R_16F, nb11/sizeof(float), - beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } - } - } -#else if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( @@ -15595,7 +15575,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, nb11 / nb10, nb12 / nb10, beta, (char *)dst_t, cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type))); - g_sycl_handles[g_main_device]->wait(); } else { const int ne23 = ne12*ne13; @@ -15626,7 +15605,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, nb02, nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); }); - }).wait(); + }); } SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans, @@ -15637,9 +15616,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, dpct::library_data_t::real_half, nb11 / nb10, beta, (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type))); - g_sycl_handles[g_main_device]->wait(); } -#endif if (no_mixed_dtypes) { const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);