From e2b065071c5fc8ac5697d12ca343551faee465cc Mon Sep 17 00:00:00 2001 From: Neo Zhang <14088817+arthw@users.noreply.github.com> Date: Tue, 28 May 2024 17:53:37 +0800 Subject: [PATCH] [SYCL]fix ggml_sycl_mul_mat_id() to match the change of api (#7436) * fix mul_mat_id to match the change of api * rm comment * rm unused or duplicated code, rename as review comment --- ggml-sycl.cpp | 273 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 219 insertions(+), 54 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index d5384b2e0..022a52aeb 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -2944,6 +2944,57 @@ namespace dpct using shared_memory = detail::device_memory; + template + inline T atomic_fetch_add(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_add(operand); + } + + template + inline T1 atomic_fetch_add(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_add(operand); + } + + template + inline T atomic_fetch_add(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_add(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_add(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_add(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } + } + + template + inline T1 atomic_fetch_add(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_add(addr, operand, memoryOrder); + } + } // COPY from DPCT head files #define GGML_COMMON_DECL_SYCL @@ -3060,6 +3111,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d bool ggml_backend_is_sycl(ggml_backend_t backend); int ggml_backend_sycl_get_device(ggml_backend_t backend); int get_main_device(); +static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer); void print_ggml_tensor(const char*name, struct ggml_tensor *src); void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt); @@ -15459,22 +15511,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) { } #endif +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +__dpct_inline__ static void k_copy_src1_to_contiguous( + const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, + int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping, + const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + int64_t ne11, int64_t ne10, size_t nb11, size_t nb12, + const sycl::nd_item<3> &item_ct1, int &src1_row) { + int32_t iid1 = item_ct1.get_group(2); + int32_t id = item_ct1.get_group(1); + + const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + + if (row_id_i != i02) { + return; + } + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + if (item_ct1.get_local_id(2) == 0) { + src1_row = + dpct::atomic_fetch_add( + cur_src1_row, 1); + row_mapping[src1_row] = {id, iid1}; + } + /* + DPCT1065:194: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); + float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); + +#pragma unroll + for (int i = item_ct1.get_local_id(2); i < ne10; + i += item_ct1.get_local_range(2)) { + src1_row_contiguous[i] = src1_row_original[i]; + } +} + +__dpct_inline__ static void k_copy_dst_from_contiguous( + char *__restrict__ dst_original, const char *__restrict__ dst_contiguous, + const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1, + size_t nb2, const sycl::nd_item<3> &item_ct1) { + int32_t i = item_ct1.get_group(2); + + const int32_t i1 = row_mapping[i].i1; + const int32_t i2 = row_mapping[i].i2; + + const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); + float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); + +#pragma unroll + for (int j = item_ct1.get_local_id(2); j < ne0; + j += item_ct1.get_local_range(2)) { + dst_row_original[j] = dst_row_contiguous[j]; + } +} + static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { - GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT && - "mul_mat_id does not support split buffers"); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers"); + const ggml_tensor *ids = dst->src[2]; + GGML_TENSOR_BINARY_OP_LOCALS + const dpct::queue_ptr stream = g_syclStreams[g_main_device][0]; - const size_t nb11 = src1->nb[1]; - const size_t nb1 = dst->nb[1]; - - const int32_t id = ((int32_t *)dst->op_params)[0]; - const int32_t n_as = src0->ne[2]; + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; std::vector ids_host(ggml_nbytes(ids)); - const char *ids_dev = (const char *)ids->data; + const char * ids_dev = (const char *) ids->data; SYCL_CHECK(CHECK_TRY_ERROR( stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); @@ -15514,24 +15630,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, src0_row.ne[2] = 1; src0_row.ne[3] = 1; - src0_row.nb[3] = src0->nb[2]; + src0_row.nb[3] = nb02; - if (src1->ne[1] == 1) { - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id = - *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] + - id * ids->nb[0]); + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; - GGML_ASSERT(row_id >= 0 && row_id < n_as); + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + if (ne12 == 1) { + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = id; + const int64_t i2 = i12; src0_row_extra.data_device[g_main_device] = - src0_original + row_id * src0->nb[2]; + src0_original + i02*nb02; src1_row_extra.data_device[g_main_device] = - src1_original + i01 * src1->nb[1]; + src1_original + + i11*nb11 + i12*nb12; dst_row_extra.data_device[g_main_device] = - dst_original + i01 * dst->nb[1]; + dst_original + i1*nb1 + i2*nb2; ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row); + } } } else { sycl_pool_alloc src1_contiguous(sizeof(float)*ggml_nelements(src1)); @@ -15540,64 +15672,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, src1_row_extra.data_device[g_main_device] = src1_contiguous.get(); dst_row_extra.data_device[g_main_device] = dst_contiguous.get(); - for (int32_t row_id = 0; row_id < n_as; ++row_id) { + for (int64_t i02 = 0; i02 < n_as; i02++) { int64_t num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - if (row_id_i != row_id) { - continue; + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + if (row_id_i != i02) { + continue; + } + + num_src1_rows++; } - - GGML_ASSERT(row_id >= 0 && row_id < n_as); - - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11, - src1_original + i01 * nb11, nb11))); - num_src1_rows++; } if (num_src1_rows == 0) { continue; } - src0_row_extra.data_device[g_main_device] = - src0_original + row_id * src0->nb[2]; + sycl_pool_alloc dev_cur_src1_row(1); + sycl_pool_alloc dev_row_mapping(num_src1_rows); + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); + + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u)); + sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor src1_row_acc(cgh); + + char *__restrict src1_contiguous_get = + src1_contiguous.get(); + int *__restrict dev_cur_src1_row_get = + dev_cur_src1_row.get(); + mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + size_t ids_nb_ct6 = ids->nb[1]; + size_t ids_nb_ct7 = ids->nb[0]; + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_src1_to_contiguous( + src1_original, src1_contiguous_get, + dev_cur_src1_row_get, + dev_row_mapping_get, ids_dev, i02, + ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12, + item_ct1, src1_row_acc); + }); + }); + } + + src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02; + + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); src1_row.ne[1] = num_src1_rows; - dst_row.ne[1] = num_src1_rows; src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row); - num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u)); + sycl::range<3> grid_dims(1, 1, num_src1_rows); + stream->submit([&](sycl::handler &cgh) { + const char *__restrict dst_contiguous_get = + dst_contiguous.get(); + const mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); - if (row_id_i != row_id) { - continue; - } - - GGML_ASSERT(row_id >= 0 && row_id < n_as); - - SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy( - dst_original + i01 * nb1, - dst_contiguous.get() + num_src1_rows * nb1, nb1))); - num_src1_rows++; + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_dst_from_contiguous(dst_original, + dst_contiguous_get, + dev_row_mapping_get, + ne0, nb1, nb2, item_ct1); + }); + }); } } } - - if (dst->backend == GGML_BACKEND_TYPE_CPU) { - SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); - } } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -16580,10 +16746,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe UNUSED(buffer); } -// unused at the moment -//static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) { -// return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name; -//} +static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) { + return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name; +} GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;