From 4833ac209da6a427de64f97e8f403dcdc5de6bc3 Mon Sep 17 00:00:00 2001 From: AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> Date: Mon, 5 Feb 2024 07:08:24 +0000 Subject: [PATCH] [SYCL] Fix cpy with dims of 3 (#5289) * Fix cpy with dims of 3 * rm asserts --------- Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> --- ggml-sycl.cpp | 194 +++++++++++++++++++++++++++++--------------------- 1 file changed, 114 insertions(+), 80 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 51445b5e7..a03df4c65 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -7693,6 +7693,13 @@ static void cpy_1_f16_f16(const char * cxi, char * cdsti) { *dsti = *xi; } +static void cpy_1_f16_f32(const char * cxi, char * cdsti) { + const sycl::half *xi = (const sycl::half *)cxi; + float *dsti = (float *)cdsti; + + *dsti = *xi; +} + static void cpy_1_i16_i16(const char * cxi, char * cdsti) { const int16_t *xi = (const int16_t *)cxi; int16_t *dsti = (int16_t *)cdsti; @@ -7709,9 +7716,9 @@ static void cpy_1_i32_i32(const char * cxi, char * cdsti) { template static void cpy_f32_f16(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, - const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, - const sycl::nd_item<3> &item_ct1) { + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -7721,15 +7728,17 @@ static void cpy_f32_f16(const char * cx, char * cdst, const int ne, // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // then combine those indices with the corresponding byte offsets to get the total offsets - const int i02 = i / (ne00*ne01); - const int i01 = (i - i02*ne01*ne00) / ne00; - const int i00 = i - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; + const int i03 = i/(ne00 * ne01 * ne02); + const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - const int i12 = i / (ne10*ne11); - const int i11 = (i - i12*ne10*ne11) / ne10; - const int i10 = i - i12*ne10*ne11 - i11*ne10; - const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; + const int i13 = i/(ne10 * ne11 * ne12); + const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; cpy_1(cx + x_offset, cdst + dst_offset); } @@ -7823,9 +7832,9 @@ static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { template static void cpy_f32_q(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, - const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, - const sycl::nd_item<3> &item_ct1) { + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) { const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk; @@ -7834,15 +7843,17 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne, return; } - const int i02 = i / (ne00*ne01); - const int i01 = (i - i02*ne01*ne00) / ne00; - const int i00 = (i - i02*ne01*ne00 - i01*ne00); - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; + const int i03 = i/(ne00 * ne01 * ne02); + const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - const int i12 = i / (ne10*ne11); - const int i11 = (i - i12*ne10*ne11) / ne10; - const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk; - const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; + const int i13 = i/(ne10 * ne11 * ne12); + const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -10599,10 +10610,12 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl( static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00, const int ne01, - const int nb00, const int nb01, - const int nb02, const int ne10, - const int ne11, const int nb10, - const int nb11, const int nb12, + const int ne02, const int nb00, + const int nb01, const int nb02, + const int nb03, const int ne10, + const int ne11, const int ne12, + const int nb10, const int nb11, + const int nb12, const int nb13, dpct::queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; @@ -10615,8 +10628,8 @@ static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne, sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, nb00, nb01, - nb02, ne10, ne11, nb10, nb11, nb12, + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); }); } @@ -10624,10 +10637,12 @@ static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne, static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne, const int ne00, const int ne01, - const int nb00, const int nb01, - const int nb02, const int ne10, - const int ne11, const int nb10, - const int nb11, const int nb12, + const int ne02, const int nb00, + const int nb01, const int nb02, + const int nb03, const int ne10, + const int ne11, const int ne12, + const int nb10, const int nb11, + const int nb12, const int nb13, dpct::queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; @@ -10640,8 +10655,8 @@ static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne, sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, nb00, nb01, - nb02, ne10, ne11, nb10, nb11, nb12, + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); }); } @@ -10649,10 +10664,12 @@ static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne, static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne, const int ne00, const int ne01, - const int nb00, const int nb01, - const int nb02, const int ne10, - const int ne11, const int nb10, - const int nb11, const int nb12, + const int ne02, const int nb00, + const int nb01, const int nb02, + const int nb03, const int ne10, + const int ne11, const int ne12, + const int nb10, const int nb11, + const int nb12, const int nb13, dpct::queue_ptr stream) { GGML_ASSERT(ne % QK8_0 == 0); @@ -10661,17 +10678,20 @@ static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { cpy_f32_q( - cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, item_ct1); + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); }); } static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne, const int ne00, const int ne01, - const int nb00, const int nb01, - const int nb02, const int ne10, - const int ne11, const int nb10, - const int nb11, const int nb12, + const int ne02, const int nb00, + const int nb01, const int nb02, + const int nb03, const int ne10, + const int ne11, const int ne12, + const int nb10, const int nb11, + const int nb12, const int nb13, dpct::queue_ptr stream) { GGML_ASSERT(ne % QK4_0 == 0); @@ -10680,17 +10700,20 @@ static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { cpy_f32_q( - cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, item_ct1); + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); }); } static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne, const int ne00, const int ne01, - const int nb00, const int nb01, - const int nb02, const int ne10, - const int ne11, const int nb10, - const int nb11, const int nb12, + const int ne02, const int nb00, + const int nb01, const int nb02, + const int nb03, const int ne10, + const int ne11, const int ne12, + const int nb10, const int nb11, + const int nb12, const int nb13, dpct::queue_ptr stream) { GGML_ASSERT(ne % QK4_1 == 0); @@ -10699,17 +10722,20 @@ static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { cpy_f32_q( - cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, item_ct1); + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); }); } static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne, const int ne00, const int ne01, - const int nb00, const int nb01, - const int nb02, const int ne10, - const int ne11, const int nb10, - const int nb11, const int nb12, + const int ne02, const int nb00, + const int nb01, const int nb02, + const int nb03, const int ne10, + const int ne11, const int ne12, + const int nb10, const int nb11, + const int nb12, const int nb13, dpct::queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; @@ -10722,8 +10748,8 @@ static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne, sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, nb00, nb01, - nb02, ne10, ne11, nb10, nb11, nb12, + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); }); } @@ -10731,10 +10757,12 @@ static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne, static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne, const int ne00, const int ne01, - const int nb00, const int nb01, - const int nb02, const int ne10, - const int ne11, const int nb10, - const int nb11, const int nb12, + const int ne02, const int nb00, + const int nb01, const int nb02, + const int nb03, const int ne10, + const int ne11, const int ne12, + const int nb10, const int nb11, + const int nb12, const int nb13, dpct::queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; @@ -10747,8 +10775,8 @@ static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne, sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, nb00, nb01, - nb02, ne10, ne11, nb10, nb11, nb12, + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); }); } @@ -10756,10 +10784,12 @@ static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne, static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne, const int ne00, const int ne01, - const int nb00, const int nb01, - const int nb02, const int ne10, - const int ne11, const int nb10, - const int nb11, const int nb12, + const int ne02, const int nb00, + const int nb01, const int nb02, + const int nb03, const int ne10, + const int ne11, const int ne12, + const int nb10, const int nb11, + const int nb12, const int nb13, dpct::queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; @@ -10772,8 +10802,8 @@ static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne, sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, nb00, nb01, - nb02, ne10, ne11, nb10, nb11, nb12, + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); }); } @@ -13910,19 +13940,23 @@ static void ggml_sycl_cpy(const ggml_tensor *src0, const ggml_tensor *src1, const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - GGML_ASSERT(src0->ne[3] == 1); + const int64_t ne02 = src0->ne[2]; + const int64_t nb00 = src0->nb[0]; const int64_t nb01 = src0->nb[1]; const int64_t nb02 = src0->nb[2]; + const int64_t nb03 = src0->nb[3]; const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; - GGML_ASSERT(src1->ne[3] == 1); + const int64_t ne12 = src1->ne[2]; + const int64_t nb10 = src1->nb[0]; const int64_t nb11 = src1->nb[1]; const int64_t nb12 = src1->nb[2]; + const int64_t nb13 = src1->nb[3]; SYCL_CHECK(ggml_sycl_set_device(g_main_device)); dpct::queue_ptr main_stream = g_syclStreams[g_main_device_index][0]; @@ -13934,21 +13968,21 @@ static void ggml_sycl_cpy(const ggml_tensor *src0, const ggml_tensor *src1, char * src1_ddc = (char *) src1_extra->data_device[g_main_device_index]; if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) { - ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { - ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type));