From 3bcd40b3c593d14261fb2abfabad3c0fb5b9e318 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Thu, 7 Nov 2024 18:19:10 +1100 Subject: [PATCH] Optimize RWKV6 Operator Naming and Implement Multi-core CPU/ SYCL Acceleration (#10133) * rwkv6: rename to wkv6 * rwkv6: support avx2 avx512 armv8 armv9 * rwkv6: update cuda file name * rwkv6: rename params * wkv on sycl * sycl: add some ops * sycl: Enhance OP support judgment * wkv6: drop armv9 and tranfer to GGML style ggml-ci * sync : ggml * update the function to use appropriate types * fix define error * Update ggml/src/ggml-cpu.c * add appropriate asserts * move element-wise functions outside * put the declaration outside the loop * rewrite to be more inline with the common pattern for distributing threads * use recommended way GGML_TENSOR_LOCALS --------- Co-authored-by: Georgi Gerganov Co-authored-by: Diego Devesa Co-authored-by: Plamen Minev Co-authored-by: Yuri Khrustalev Co-authored-by: Meng, Hengyu --- docs/backend/SYCL.md | 2 +- ggml/include/ggml.h | 4 +- ggml/src/ggml-cpu.c | 208 +++- ggml/src/ggml-cuda.cu | 8 +- ggml/src/ggml-cuda/rwkv-wkv.cuh | 5 - ggml/src/ggml-cuda/{rwkv-wkv.cu => wkv6.cu} | 6 +- ggml/src/ggml-cuda/wkv6.cuh | 5 + ggml/src/ggml-sycl.cpp | 1121 +++---------------- ggml/src/ggml-sycl/backend.hpp | 3 + ggml/src/ggml-sycl/common.cpp | 40 + ggml/src/ggml-sycl/common.hpp | 258 +++++ ggml/src/ggml-sycl/concat.cpp | 1 + ggml/src/ggml-sycl/element_wise.cpp | 1011 +++++++++++++++++ ggml/src/ggml-sycl/element_wise.hpp | 76 ++ ggml/src/ggml-sycl/outprod.cpp | 55 + ggml/src/ggml-sycl/outprod.hpp | 11 + ggml/src/ggml-sycl/presets.hpp | 6 + ggml/src/ggml-sycl/wkv6.cpp | 138 +++ ggml/src/ggml-sycl/wkv6.hpp | 10 + ggml/src/ggml.c | 12 +- src/llama.cpp | 8 +- tests/test-backend-ops.cpp | 16 +- 22 files changed, 1977 insertions(+), 1027 deletions(-) delete mode 100644 ggml/src/ggml-cuda/rwkv-wkv.cuh rename ggml/src/ggml-cuda/{rwkv-wkv.cu => wkv6.cu} (93%) create mode 100644 ggml/src/ggml-cuda/wkv6.cuh create mode 100644 ggml/src/ggml-sycl/element_wise.cpp create mode 100644 ggml/src/ggml-sycl/element_wise.hpp create mode 100644 ggml/src/ggml-sycl/outprod.cpp create mode 100644 ggml/src/ggml-sycl/outprod.hpp create mode 100644 ggml/src/ggml-sycl/wkv6.cpp create mode 100644 ggml/src/ggml-sycl/wkv6.hpp diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index ea34182e4..bc8c0f886 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -377,7 +377,7 @@ found 2 SYCL devices: |Chosen Device ID|Setting| |-|-| -|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"` or no action| +|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:0"` or no action| |1|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"`| |0 & 1|`export ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"`| diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 8a0bcbff8..0d143d2fe 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -509,7 +509,7 @@ extern "C" { GGML_OP_WIN_UNPART, GGML_OP_GET_REL_POS, GGML_OP_ADD_REL_POS, - GGML_OP_RWKV_WKV, + GGML_OP_RWKV_WKV6, GGML_OP_UNARY, @@ -1819,7 +1819,7 @@ extern "C" { struct ggml_tensor * pw, struct ggml_tensor * ph); - GGML_API struct ggml_tensor * ggml_rwkv_wkv( + GGML_API struct ggml_tensor * ggml_rwkv_wkv6( struct ggml_context * ctx, struct ggml_tensor * k, struct ggml_tensor * v, diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c index 0cb5b824a..98c3e21ae 100644 --- a/ggml/src/ggml-cpu.c +++ b/ggml/src/ggml-cpu.c @@ -11642,24 +11642,30 @@ static void ggml_compute_forward_add_rel_pos( } } -// ggml_compute_forward_rwkv_wkv +// ggml_compute_forward_rwkv_wkv6 -static void ggml_compute_forward_rwkv_wkv_f32( +static void ggml_compute_forward_rwkv_wkv6_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const size_t T = dst->src[1]->ne[3]; - const size_t C = dst->ne[0]; - const size_t H = dst->src[1]->ne[2]; - const size_t n_seqs = dst->src[5]->ne[1]; + const int64_t T = dst->src[1]->ne[3]; + const int64_t C = dst->ne[0]; + const int64_t HEADS = dst->src[1]->ne[2]; + const int64_t n_seqs = dst->src[5]->ne[1]; + const int64_t head_size = C / HEADS; float * dst_data = (float *) dst->data; float * state = ((float *) dst->data) + C * T; - if (params->ith != 0) { + const int ith = params->ith; + const int nth = params->nth; + + if (ith >= HEADS) { return; } - memset(dst_data, 0, T * C * sizeof(float)); + const int h_start = (HEADS * ith) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -11667,54 +11673,160 @@ static void ggml_compute_forward_rwkv_wkv_f32( float * time_faaaa = (float *) dst->src[3]->data; float * time_decay = (float *) dst->src[4]->data; - size_t t_stride = H * (C / H); + size_t t_stride = HEADS * head_size; // Same to C - size_t h_stride = C / H; - size_t h_stride_2d = (C / H) * (C / H); + size_t h_stride = C / HEADS; + GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS + size_t h_stride_2d = head_size * head_size; - // basically fused operations: - // dst = r @ (time_faaaa * (k @ v) + state), - // state = time_decay * state + (k @ v), - // recursive through each token - for (size_t t = 0; t < T; t++) { - size_t t_offset = t * t_stride; - size_t state_offset = (C / H) * C * (t / (T / n_seqs)); - float * state_cur = state + state_offset; - float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; + if (ith == 0) { + memset(dst_data, 0, T * C * sizeof(float)); + } + ggml_barrier(params->threadpool); - for (size_t h = 0; h < H; h++) { - size_t h_offset = h * h_stride; - size_t t_h_offset = t_offset + h_offset; - size_t h_2d_offset = h * h_stride_2d; - for (size_t i = 0; i < C / H; i++) { - size_t t_h_i_offset = t_h_offset + i; - size_t h_i_offset = h_offset + i; - size_t h_2d_i_offset = h_2d_offset + i * h_stride; + #if defined(__AVX__) && !defined(__AVX512F__) + #define GGML_F32X GGML_F32x8 + #define GGML_F32X_SET1 GGML_F32x8_SET1 + #define GGML_F32X_LOAD GGML_F32x8_LOAD + #define GGML_F32X_STORE GGML_F32x8_STORE + #define GGML_F32X_MUL GGML_F32x8_MUL + #define GGML_F32X_FMA GGML_F32x8_FMA + #define WKV_VECTOR_SIZE 8 + #elif defined(__AVX512F__) + #define GGML_F32X GGML_F32x16 + #define GGML_F32X_SET1 GGML_F32x16_SET1 + #define GGML_F32X_LOAD GGML_F32x16_LOAD + #define GGML_F32X_STORE GGML_F32x16_STORE + #define GGML_F32X_MUL GGML_F32x16_MUL + #define GGML_F32X_FMA GGML_F32x16_FMA + #define WKV_VECTOR_SIZE 16 + #elif defined(__ARM_NEON) && defined(__aarch64__) + #define GGML_F32X GGML_F32x4 + #define GGML_F32X_SET1 GGML_F32x4_SET1 + #define GGML_F32X_LOAD GGML_F32x4_LOAD + #define GGML_F32X_STORE GGML_F32x4_STORE + #define GGML_F32X_MUL GGML_F32x4_MUL + #define GGML_F32X_FMA GGML_F32x4_FMA + #define WKV_VECTOR_SIZE 4 + #endif - float k_val = k[t_h_i_offset]; - float r_val = r[t_h_i_offset]; - float time_faaaa_val = time_faaaa[h_i_offset]; - // RWKV v6: different time_decay for each token. - float time_decay_val = time_decay[t_h_i_offset]; + #ifdef WKV_VECTOR_SIZE + const int64_t vec_count = head_size / WKV_VECTOR_SIZE; - for (size_t j = 0; j < C / H; j ++) { - size_t t_h_j_offset = t_h_offset + j; - size_t h_2d_i_j_offset = h_2d_i_offset + j; + for (int64_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - float v_val = v[t_h_j_offset]; - float kv_val = v_val * k_val; - float prev_state_val = state_prev[h_2d_i_j_offset]; - float temp_val = kv_val * time_faaaa_val + prev_state_val; - dst_data[t_h_j_offset] += temp_val * r_val; - state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + for (int64_t h = h_start; h < h_end; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (int64_t i = 0; i < head_size; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float r_val = r[t_h_i_offset]; + float time_faaaa_val = time_faaaa[h_i_offset]; + float time_decay_val = time_decay[t_h_i_offset]; + + // Broadcast scalar values to vectors + GGML_F32X k_vec = GGML_F32X_SET1(k_val); + GGML_F32X r_vec = GGML_F32X_SET1(r_val); + GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val); + GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val); + + for (int64_t j = 0; j < vec_count; j++) { + size_t base_j = j * WKV_VECTOR_SIZE; + size_t t_h_j_offset = t_h_offset + base_j; + size_t h_2d_i_j_offset = h_2d_i_offset + base_j; + + // Load x elements at once + GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]); + GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]); + GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]); + + // Compute kv = v * k + GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec); + + // Compute temp = kv * time_faaaa + prev_state + GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec); + + // Update dst: dst += temp * r + dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec); + GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec); + + // Update state: state = prev_state * time_decay + kv + GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec); + GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec); + } + + // Handle remaining elements, this will not be used. + for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = kv_val * time_faaaa_val + prev_state_val; + dst_data[t_h_j_offset] += temp_val * r_val; + state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + } } } } - } + + #else + // basically fused operations: + // dst = r @ (time_faaaa * (k @ v) + state), + // state = time_decay * state + (k @ v), + // recursive through each token + for (int64_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (int64_t i = 0; i < head_size; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float r_val = r[t_h_i_offset]; + float time_faaaa_val = time_faaaa[h_i_offset]; + // RWKV v6: different time_decay for each token. + float time_decay_val = time_decay[t_h_i_offset]; + + for (int64_t j = 0; j < head_size; j++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = kv_val * time_faaaa_val + prev_state_val; + dst_data[t_h_j_offset] += temp_val * r_val; + state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + } + } + } + } + #endif } -static void ggml_compute_forward_rwkv_wkv( + +static void ggml_compute_forward_rwkv_wkv6( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_rwkv_wkv_f32(params, dst); + ggml_compute_forward_rwkv_wkv6_f32(params, dst); } break; default: { @@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_add_rel_pos(params, tensor); } break; - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: { - ggml_compute_forward_rwkv_wkv(params, tensor); + ggml_compute_forward_rwkv_wkv6(params, tensor); } break; case GGML_OP_MAP_UNARY: { @@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_GET_REL_POS: - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index e68e40550..e27c8e87d 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -36,7 +36,7 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" -#include "ggml-cuda/rwkv-wkv.cuh" +#include "ggml-cuda/wkv6.cuh" #include #include @@ -2319,8 +2319,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; - case GGML_OP_RWKV_WKV: - ggml_cuda_op_rwkv_wkv(ctx, dst); + case GGML_OP_RWKV_WKV6: + ggml_cuda_op_rwkv_wkv6(ctx, dst); break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_cuda_cross_entropy_loss_back(ctx, dst); @@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: return true; case GGML_OP_FLASH_ATTN_EXT: { #ifndef FLASH_ATTN_AVAILABLE diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cuh b/ggml/src/ggml-cuda/rwkv-wkv.cuh deleted file mode 100644 index 13795247f..000000000 --- a/ggml/src/ggml-cuda/rwkv-wkv.cuh +++ /dev/null @@ -1,5 +0,0 @@ -#include "common.cuh" - -#define CUDA_WKV_BLOCK_SIZE 64 - -void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cu b/ggml/src/ggml-cuda/wkv6.cu similarity index 93% rename from ggml/src/ggml-cuda/rwkv-wkv.cu rename to ggml/src/ggml-cuda/wkv6.cu index 098e92d35..42578341a 100644 --- a/ggml/src/ggml-cuda/rwkv-wkv.cu +++ b/ggml/src/ggml-cuda/wkv6.cu @@ -1,5 +1,5 @@ #include "common.cuh" -#include "rwkv-wkv.cuh" +#include "wkv6.cuh" static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) { const int tid = threadIdx.x; @@ -64,7 +64,7 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const } } -void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const float * k_d = (const float *)dst->src[0]->data; const float * v_d = (const float *)dst->src[1]->data; const float * r_d = (const float *)dst->src[2]->data; @@ -83,7 +83,7 @@ void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); GGML_ASSERT(C % H == 0); - GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); + GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64 rwkv_wkv_f32<<>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d); } diff --git a/ggml/src/ggml-cuda/wkv6.cuh b/ggml/src/ggml-cuda/wkv6.cuh new file mode 100644 index 000000000..a7124ee51 --- /dev/null +++ b/ggml/src/ggml-cuda/wkv6.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_WKV_BLOCK_SIZE 64 + +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index a62c67f4f..255bc64c6 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -1194,272 +1194,8 @@ typedef void (*ggml_sycl_op_mul_mat_t)( float *dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, const queue_ptr &stream); -typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream); -static __dpct_inline__ float op_repeat(const float a, const float b) { - return b; - GGML_UNUSED(a); -} -static __dpct_inline__ float op_add(const float a, const float b) { - return a + b; -} - -static __dpct_inline__ float op_mul(const float a, const float b) { - return a * b; -} - -static __dpct_inline__ float op_div(const float a, const float b) { - return a / b; -} - -template -static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); - const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) / - ne3; - const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) % - ne3; - - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i_src0; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - for (int i0 = i0s; i0 < ne0; - i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); - } -} - -template -static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; - - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i_src0; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); -} - -static void acc_f32(const float * x, const float * y, float * dst, const int ne, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= ne) { - return; - } - int src1_idx = i - offset; - int oz = src1_idx / nb2; - int oy = (src1_idx - (oz * nb2)) / nb1; - int ox = src1_idx % nb1; - if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { - dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; - } else { - dst[i] = x[i]; - } -} - -static void gelu_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - - float xi = x[i]; - dst[i] = 0.5f * xi * - (1.0f + - sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi))); -} - -static void silu_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i])); -} - -static void gelu_quick_f32(const float *x, float *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const float GELU_QUICK_COEF = -1.702f; - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; - } - dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i]))); -} - -static void tanh_f32(const float *x, float *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; - } - dst[i] = sycl::tanh((float)(x[i])); -} - -static void relu_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::fmax((float)(x[i]), (float)0); -} - -static void hardsigmoid_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); -} - -static void hardswish_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); -} - -static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; - } - dst[i] = sycl::fmax((float)(x[i]), (float)0) + - sycl::fmin((float)(x[i]), 0.0f) * negative_slope; -} - -static void sqr_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = x[i] * x[i]; -} - -static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { - int index = item_ct1.get_local_id(0) + - item_ct1.get_group(0) * item_ct1.get_local_range(0); - if (index >= ne10 * ne11 * ne12 * ne13) { - return; - } - // operation - int i10 = index % ne10; - int i11 = (index / ne10) % ne11; - int i12 = (index / (ne10 * ne11)) % ne12; - int i13 = (index / (ne10 * ne11 * ne12)) % ne13; - - int i00 = i10 / sf0; - int i01 = i11 / sf1; - int i02 = i12 / sf2; - int i03 = i13 / sf3; - - dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); -} - -static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02, - const sycl::nd_item<3> &item_ct1) { - int nidx = item_ct1.get_local_id(2) + - item_ct1.get_group(2) * item_ct1.get_local_range(2); - if (nidx >= ne0) { - return; - } - - // operation - int offset_dst = nidx + item_ct1.get_group(1) * ne0 + - item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); - if (nidx < ne00 && item_ct1.get_group(1) < ne01 && - item_ct1.get_group(0) < ne02) { - int offset_src = nidx + item_ct1.get_group(1) * ne00 + - item_ct1.get_group(0) * ne00 * ne01; - dst[offset_dst] = x[offset_src]; - } else { - dst[offset_dst] = 0.0f; - } -} template static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, @@ -2148,297 +1884,6 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens (void) dst; } -template -struct bin_bcast_sycl { - template - void operator()(ggml_backend_sycl_context & ctx, - const struct ggml_tensor *src0, - const struct ggml_tensor *src1, struct ggml_tensor *dst, - const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd, - queue_ptr stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - int nr0 = ne10/ne0; - int nr1 = ne11/ne1; - int nr2 = ne12/ne2; - int nr3 = ne13/ne3; - - int nr[4] = { nr0, nr1, nr2, nr3 }; - - // collapse dimensions until first broadcast dimension - int64_t cne0[] = {ne0, ne1, ne2, ne3}; - int64_t cne1[] = {ne10, ne11, ne12, ne13}; - size_t cnb0[] = {nb0, nb1, nb2, nb3}; - size_t cnb1[] = {nb10, nb11, nb12, nb13}; - auto collapse = [](int64_t cne[]) { - cne[0] *= cne[1]; - cne[1] = cne[2]; - cne[2] = cne[3]; - cne[3] = 1; - }; - - auto collapse_nb = [](size_t cnb[], int64_t cne[]) { - cnb[1] *= cne[1]; - cnb[2] *= cne[2]; - cnb[3] *= cne[3]; - }; - - for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne0); - collapse(cne1); - } - } - { - int64_t ne0 = cne0[0]; - int64_t ne1 = cne0[1]; - int64_t ne2 = cne0[2]; - int64_t ne3 = cne0[3]; - - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - - size_t nb0 = cnb0[0]; - size_t nb1 = cnb0[1]; - size_t nb2 = cnb0[2]; - size_t nb3 = cnb0[3]; - - size_t nb10 = cnb1[0]; - size_t nb11 = cnb1[1]; - size_t nb12 = cnb1[2]; - size_t nb13 = cnb1[3]; - - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); - - size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s10 == 1); - - const int block_size = 128; - - int64_t hne0 = std::max(ne0/2LL, 1LL); - - sycl::range<3> block_dims(1, 1, 1); - block_dims[2] = std::min(hne0, block_size); - block_dims[1] = std::min( - ne1, block_size / (unsigned int)block_dims[2]); - block_dims[0] = std::min( - std::min( - ne2 * ne3, block_size / (unsigned int)block_dims[2] / - (unsigned int)block_dims[1]), - 64U); - - sycl::range<3> block_nums( - (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], - (ne1 + block_dims[1] - 1) / block_dims[1], - (hne0 + block_dims[2] - 1) / block_dims[2]); - - if (block_nums[0] > 65535) { - // this is the maximum number of blocks in z direction, fallback to 1D grid kernel - int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * - sycl::range<3>(1, 1, block_size), - sycl::range<3>(1, 1, block_size)), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast_unravel( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12, - s13, item_ct1); - }); - } - } else { - /* - DPCT1049:16: The work-group size passed to the SYCL kernel may - exceed the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if - needed. - */ - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, - ne2, ne3, ne10, ne11, ne12, ne13, - s1, s2, s3, s11, s12, s13, - item_ct1); - }); - } - } - } -}; - -static void acc_f32_sycl(const float *x, const float *y, float *dst, - const int n_elements, const int ne10, const int ne11, - const int ne12, const int nb1, const int nb2, - const int offset, queue_ptr stream) { - int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, - item_ct1); - }); -} - -static void gelu_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu_f32(x, dst, k, item_ct1); - }); -} - -static void silu_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - silu_f32(x, dst, k, item_ct1); - }); -} - -static void gelu_quick_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu_quick_f32(x, dst, k, item_ct1); - }); -} - -static void tanh_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - tanh_f32(x, dst, k, item_ct1); - }); -} - -static void relu_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - relu_f32(x, dst, k, item_ct1); - }); -} - -static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - hardsigmoid_f32(x, dst, k, item_ct1); - }); -} - -static void hardswish_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - hardswish_f32(x, dst, k, item_ct1); - }); -} - -static void leaky_relu_f32_sycl(const float *x, float *dst, const int k, - const float negative_slope, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - leaky_relu_f32(x, dst, k, negative_slope, item_ct1); - }); -} - -static void sqr_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sqr_f32(x, dst, k, item_ct1); - }); -} - -static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, queue_ptr stream) { - int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; - sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); - stream->parallel_for( - sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { - upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); - }); -} - -static void pad_f32_sycl(const float *x, float *dst, const int ne00, - const int ne01, const int ne02, const int ne0, - const int ne1, const int ne2, queue_ptr stream) { - int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; - sycl::range<3> gridDim(ne2, ne1, num_blocks); - stream->parallel_for( - sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1); - }); -} static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx, const int ky, const int kx_padded, @@ -2816,6 +2261,58 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, } } +static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols, + const int nrows, queue_ptr stream) { + const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE); + const sycl::range<3> block_nums(1, nrows, 1); + const size_t shared_mem = 256 * sizeof(float); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor shared_data( + sycl::range<1>(shared_mem/sizeof(float)), cgh); + sycl::local_accessor shared_indices( + sycl::range<1>(shared_mem/sizeof(float)), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + const int tid = item_ct1.get_local_id(2); + const int row = item_ct1.get_global_id(1); + + float max_val = -INFINITY; + int max_idx = -1; + + for (int col = tid; col < ncols; col += 256) { + float val = x[row * ncols + col]; + if (val > max_val) { + max_val = val; + max_idx = col; + } + } + + shared_data[tid] = max_val; + shared_indices[tid] = max_idx; + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (int stride = 256/2; stride > 0; stride >>= 1) { + if (tid < stride) { + float val1 = shared_data[tid]; + float val2 = shared_data[tid + stride]; + if (val2 > val1) { + shared_data[tid] = val2; + shared_indices[tid] = shared_indices[tid + stride]; + } + } + item_ct1.barrier(sycl::access::fence_space::local_space); + } + + + if (tid == 0) { + dst[row] = shared_indices[0]; + } + }); + }); +} static void diag_mask_inf_f32_sycl(const float *x, float *dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, @@ -2946,33 +2443,6 @@ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_te } } -template -inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, - (sycl::half *)dst_dd, main_stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd, - main_stream); - } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { - op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd, - main_stream); - } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { - op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd, - main_stream); - } else { - fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, - ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); - } -} static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -2986,230 +2456,6 @@ static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tens (void) src1_d; } -inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); -} - -inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported - - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes - - acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream); - - (void) dst; -} - -inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); -} - -inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); -} - -inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - float negative_slope; - memcpy(&negative_slope, dst->op_params, sizeof(float)); - - leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - const float sf0 = (float)dst->ne[0]/src0->ne[0]; - const float sf1 = (float)dst->ne[1]/src0->ne[1]; - const float sf2 = (float)dst->ne[2]/src0->ne[2]; - const float sf3 = (float)dst->ne[3]/src0->ne[3]; - - upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors - - pad_f32_sycl(src0_dd, dst_dd, - src0->ne[0], src0->ne[1], src0->ne[2], - dst->ne[0], dst->ne[1], dst->ne[2], main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} inline void ggml_sycl_op_mul_mat_sycl( ggml_backend_sycl_context & ctx, @@ -3379,6 +2625,23 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens (void) src1_dd; } +inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne = ggml_nelements(src0); + + sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, @@ -3419,6 +2682,25 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_ten (void) src1_dd; } +inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, @@ -3489,46 +2771,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tenso (void) src1_dd; } -static void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const ggml_sycl_op_flatten_t op) try { - const int64_t nrows0 = ggml_nrows(src0); - - const bool use_src1 = src1 != nullptr; - const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; - - GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT); - GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT); - - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - - // dd = data device - float * src0_ddf = (float *) src0->data; - float * src1_ddf = use_src1 ? (float *) src1->data : nullptr; - float * dst_ddf = (float *) dst->data; - - ggml_sycl_pool_alloc src0_f(ctx.pool()); - ggml_sycl_pool_alloc src1_f(ctx.pool()); - ggml_sycl_pool_alloc dst_f(ctx.pool()); - - ggml_sycl_set_device(ctx.device); - queue_ptr main_stream = ctx.stream(); - // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n", - // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device); - - // do the computation - op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); - // print_ggml_tensor("tensor", dst); -} -catch (sycl::exception const &exc) { - - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { static bool peer_access_enabled = false; @@ -3908,115 +3150,24 @@ static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tenso GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - - static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm); GGML_SYCL_DEBUG("call %s done\n", __func__); } +static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { @@ -4632,6 +3783,11 @@ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col); } +static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum); +} + static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(src0)); ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows); @@ -4642,6 +3798,11 @@ static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort); } +static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax); +} + static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { (void) src0; (void) src1; @@ -4673,6 +3834,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens ggml_sycl_func_t func; switch (tensor->op) { + case GGML_OP_ARGMAX: + func = ggml_sycl_argmax; + break; case GGML_OP_CONV_TRANSPOSE_1D: func = ggml_sycl_op_conv_transpose_1d; break; @@ -4686,19 +3850,32 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens func = ggml_sycl_dup; break; case GGML_OP_ADD: + case GGML_OP_ADD1: // TODO: more efficient implementation func = ggml_sycl_add; break; + case GGML_OP_SUB: + func = ggml_sycl_sub; + break; case GGML_OP_ACC: func = ggml_sycl_acc; break; case GGML_OP_MUL: func = ggml_sycl_mul; break; + case GGML_OP_LOG: + func = ggml_sycl_log; + break; case GGML_OP_DIV: func = ggml_sycl_div; break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_NEG: + func = ggml_sycl_neg; + break; + case GGML_UNARY_OP_STEP: + func = ggml_sycl_step; + break; case GGML_UNARY_OP_GELU: func = ggml_sycl_gelu; break; @@ -4714,12 +3891,18 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens case GGML_UNARY_OP_RELU: func = ggml_sycl_relu; break; + case GGML_UNARY_OP_SIGMOID: + func = ggml_sycl_sigmoid; + break; case GGML_UNARY_OP_HARDSIGMOID: func = ggml_sycl_hardsigmoid; break; case GGML_UNARY_OP_HARDSWISH: func = ggml_sycl_hardswish; break; + case GGML_UNARY_OP_EXP: + func = ggml_sycl_exp; + break; default: return false; } @@ -4757,12 +3940,24 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens } func = ggml_sycl_mul_mat_id; break; + case GGML_OP_OUT_PROD: + func = ggml_sycl_op_out_prod; + break; case GGML_OP_SCALE: func = ggml_sycl_scale; break; case GGML_OP_SQR: func = ggml_sycl_sqr; break; + case GGML_OP_SQRT: + func = ggml_sycl_sqrt; + break; + case GGML_OP_SIN: + func = ggml_sycl_sin; + break; + case GGML_OP_COS: + func = ggml_sycl_cos; + break; case GGML_OP_CLAMP: func = ggml_sycl_clamp; break; @@ -4794,6 +3989,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens case GGML_OP_POOL_2D: func = ggml_sycl_pool2d; break; + case GGML_OP_SUM: + func = ggml_sycl_sum; + break; case GGML_OP_SUM_ROWS: func = ggml_sycl_sum_rows; break; @@ -4803,6 +4001,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens case GGML_OP_TIMESTEP_EMBEDDING: func = ggml_sycl_op_timestep_embedding; break; + case GGML_OP_RWKV_WKV6: + func = ggml_sycl_op_rwkv_wkv6; + break; default: return false; } @@ -5125,13 +4326,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_EXP: return ggml_is_contiguous(op->src[0]); default: return false; @@ -5168,6 +4373,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g } return true; } break; + case GGML_OP_OUT_PROD: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { @@ -5213,10 +4420,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; - int dim = op->op_params[0]; - return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; case GGML_OP_DUP: + case GGML_OP_ARGMAX: case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_REPEAT: @@ -5225,11 +4432,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_TRANSPOSE: case GGML_OP_NORM: case GGML_OP_ADD: + case GGML_OP_ADD1: + case GGML_OP_LOG: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: return true; case GGML_OP_CONT: @@ -5243,6 +4456,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g // TODO: add support for the new F32 operations return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_2D: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: @@ -5251,6 +4465,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_RWKV_WKV6: return true; default: return false; @@ -5268,9 +4483,23 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_ return buft_ctx->device == sycl_ctx->device; } +static int64_t get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + return op->ne[1]; // this will increse the speed of prefill in test + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { const int min_batch_size = 32; - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID; + return get_op_batch_size(op) >= min_batch_size; GGML_UNUSED(dev); } diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index d21b5f8dd..85748a5b4 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -26,5 +26,8 @@ #include "softmax.hpp" #include "tsembd.hpp" #include "im2col.hpp" +#include "wkv6.hpp" +#include "outprod.hpp" +#include "element_wise.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index cf5291b31..97ab2003c 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -62,3 +62,43 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block } return sycl_down_blk_size; } + +void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const ggml_sycl_op_flatten_t op) try { + const int64_t nrows0 = ggml_nrows(src0); + + const bool use_src1 = src1 != nullptr; + const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; + + GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + + // dd = data device + float * src0_ddf = (float *) src0->data; + float * src1_ddf = use_src1 ? (float *) src1->data : nullptr; + float * dst_ddf = (float *) dst->data; + + ggml_sycl_pool_alloc src0_f(ctx.pool()); + ggml_sycl_pool_alloc src1_f(ctx.pool()); + ggml_sycl_pool_alloc dst_f(ctx.pool()); + + ggml_sycl_set_device(ctx.device); + queue_ptr main_stream = ctx.stream(); + // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n", + // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device); + + // do the computation + op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); + // print_ggml_tensor("tensor", dst); +} +catch (sycl::exception const &exc) { + + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index bc0faa867..4549fa5e9 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -404,4 +404,262 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor acc) { int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); +typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream); + +template +static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) / + ne3; + const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) % + ne3; + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; + i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + } +} + +template +static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); +} + + +template +struct bin_bcast_sycl { + template + void operator()(ggml_backend_sycl_context & ctx, + const struct ggml_tensor *src0, + const struct ggml_tensor *src1, struct ggml_tensor *dst, + const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd, + queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + int nr0 = ne10/ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne0[] = {ne0, ne1, ne2, ne3}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb0[] = {nb0, nb1, nb2, nb3}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne0); + collapse(cne1); + } + } + { + int64_t ne0 = cne0[0]; + int64_t ne1 = cne0[1]; + int64_t ne2 = cne0[2]; + int64_t ne3 = cne0[3]; + + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; + + size_t nb0 = cnb0[0]; + size_t nb1 = cnb0[1]; + size_t nb2 = cnb0[2]; + size_t nb3 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + sycl::range<3> block_dims(1, 1, 1); + block_dims[2] = std::min(hne0, block_size); + block_dims[1] = std::min( + ne1, block_size / (unsigned int)block_dims[2]); + block_dims[0] = std::min( + std::min( + ne2 * ne3, block_size / (unsigned int)block_dims[2] / + (unsigned int)block_dims[1]), + 64U); + + sycl::range<3> block_nums( + (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], + (ne1 + block_dims[1] - 1) / block_dims[1], + (hne0 + block_dims[2] - 1) / block_dims[2]); + + if (block_nums[0] > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * + sycl::range<3>(1, 1, block_size), + sycl::range<3>(1, 1, block_size)), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast_unravel( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12, + s13, item_ct1); + }); + } + } else { + /* + DPCT1049:16: The work-group size passed to the SYCL kernel may + exceed the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if + needed. + */ + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, + ne2, ne3, ne10, ne11, ne12, ne13, + s1, s2, s3, s11, s12, s13, + item_ct1); + }); + } + } + } +}; + +template +inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, + (sycl::half *)dst_dd, main_stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd, + main_stream); + } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd, + main_stream); + } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { + op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd, + main_stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + + +void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const ggml_sycl_op_flatten_t op); + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp index 632eedb9d..c90c452d8 100644 --- a/ggml/src/ggml-sycl/concat.cpp +++ b/ggml/src/ggml-sycl/concat.cpp @@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst, concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); }); break; + // dim >=2 will be dispatched to the default path default: stream->parallel_for( sycl::nd_range<3>(gridDim * diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp new file mode 100644 index 000000000..e5cd736eb --- /dev/null +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -0,0 +1,1011 @@ +#include "common.hpp" +#include "element_wise.hpp" + +void acc_f32(const float * x, const float * y, float * dst, const int ne, + const int ne10, const int ne11, const int ne12, + const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= ne) { + return; + } + int src1_idx = i - offset; + int oz = src1_idx / nb2; + int oy = (src1_idx - (oz * nb2)) / nb1; + int ox = src1_idx % nb1; + if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { + dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; + } else { + dst[i] = x[i]; + } +} + +void gelu_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + + float xi = x[i]; + dst[i] = 0.5f * xi * + (1.0f + + sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi))); +} + +void silu_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i])); +} + +void gelu_quick_f32(const float *x, float *dst, int k, + const sycl::nd_item<3> &item_ct1) { + const float GELU_QUICK_COEF = -1.702f; + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= k) { + return; + } + dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i]))); +} + +void tanh_f32(const float *x, float *dst, int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= k) { + return; + } + dst[i] = sycl::tanh((float)(x[i])); +} + +void relu_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::fmax((float)(x[i]), (float)0); +} + +void sigmoid_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i])); +} + +void sqrt_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::sqrt(x[i]); +} + +void sin_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::sin(x[i]); +} + +void cos_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::cos(x[i]); +} + +void hardsigmoid_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); +} + +void hardswish_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); +} + +void exp_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::exp(x[i]); +} + +void log_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + float xi = x[i]; + if (xi <= 0) { + dst[i] = -INFINITY; + } else { + dst[i] = sycl::log(xi); + } +} + +void neg_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = -x[i]; +} + +void step_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] > 0.0f; +} + +void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= k) { + return; + } + dst[i] = sycl::fmax((float)(x[i]), (float)0) + + sycl::fmin((float)(x[i]), 0.0f) * negative_slope; +} + +void sqr_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] * x[i]; +} + +void upscale_f32(const float *x, float *dst, const int nb00, const int nb01, + const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int ne13, const float sf0, const float sf1, + const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { + int index = item_ct1.get_local_id(0) + + item_ct1.get_group(0) * item_ct1.get_local_range(0); + if (index >= ne10 * ne11 * ne12 * ne13) { + return; + } + // operation + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); +} + +void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02, + const sycl::nd_item<3> &item_ct1) { + int nidx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (nidx >= ne0) { + return; + } + + // operation + int offset_dst = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + if (nidx < ne00 && item_ct1.get_group(1) < ne01 && + item_ct1.get_group(0) < ne02) { + int offset_src = nidx + item_ct1.get_group(1) * ne00 + + item_ct1.get_group(0) * ne00 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + dst[offset_dst] = 0.0f; + } +} + + + +void acc_f32_sycl(const float *x, const float *y, float *dst, + const int n_elements, const int ne10, const int ne11, + const int ne12, const int nb1, const int nb2, + const int offset, queue_ptr stream) { + int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, + item_ct1); + }); +} + +void gelu_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + gelu_f32(x, dst, k, item_ct1); + }); +} + +void silu_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + silu_f32(x, dst, k, item_ct1); + }); +} + +void gelu_quick_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + gelu_quick_f32(x, dst, k, item_ct1); + }); +} + +void tanh_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + tanh_f32(x, dst, k, item_ct1); + }); +} + +void relu_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + relu_f32(x, dst, k, item_ct1); + }); +} + +void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + hardsigmoid_f32(x, dst, k, item_ct1); + }); +} + +void hardswish_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + hardswish_f32(x, dst, k, item_ct1); + }); +} + +void exp_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + exp_f32(x, dst, k, item_ct1); + }); +} + +void log_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + log_f32(x, dst, k, item_ct1); + }); +} + +void neg_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + neg_f32(x, dst, k, item_ct1); + }); +} + +void step_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + step_f32(x, dst, k, item_ct1); + }); +} + +void sigmoid_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sigmoid_f32(x, dst, k, item_ct1); + }); +} + +void sqrt_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sqrt_f32(x, dst, k, item_ct1); + }); +} + +void sin_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sin_f32(x, dst, k, item_ct1); + }); +} + +void cos_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cos_f32(x, dst, k, item_ct1); + }); +} + +void leaky_relu_f32_sycl(const float *x, float *dst, const int k, + const float negative_slope, + queue_ptr stream) { + const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + leaky_relu_f32(x, dst, k, negative_slope, item_ct1); + }); +} + +void sqr_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sqr_f32(x, dst, k, item_ct1); + }); +} + +void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01, + const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int ne13, const float sf0, const float sf1, + const float sf2, const float sf3, queue_ptr stream) { + int dst_size = ne10 * ne11 * ne12 * ne13; + int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); + }); +} + +void pad_f32_sycl(const float *x, float *dst, const int ne00, + const int ne01, const int ne02, const int ne0, + const int ne1, const int ne2, queue_ptr stream) { + int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; + sycl::range<3> gridDim(ne2, ne1, num_blocks); + stream->parallel_for( + sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1); + }); +} + +inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} +inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + + leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const float sf0 = (float)dst->ne[0]/src0->ne[0]; + const float sf1 = (float)dst->ne[1]/src0->ne[1]; + const float sf2 = (float)dst->ne[2]/src0->ne[2]; + const float sf3 = (float)dst->ne[3]/src0->ne[3]; + + upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, + main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + + pad_f32_sycl(src0_dd, dst_dd, + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream); + + (void) dst; +} + +inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + +inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + +inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + +inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + + +void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqrt); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sin); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_cos); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sigmoid); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + + +void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_exp); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_log); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_neg); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_step); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + + + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sub); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp new file mode 100644 index 000000000..8152edf58 --- /dev/null +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -0,0 +1,76 @@ +#ifndef GGML_SYCL_ELEMENTWISE_HPP +#define GGML_SYCL_ELEMENTWISE_HPP + +#include "common.hpp" + +static __dpct_inline__ float op_repeat(const float a, const float b) { + return b; + GGML_UNUSED(a); +} + +static __dpct_inline__ float op_add(const float a, const float b) { + return a + b; +} + +static __dpct_inline__ float op_sub(const float a, const float b) { + return a - b; +} + +static __dpct_inline__ float op_mul(const float a, const float b) { + return a * b; +} + +static __dpct_inline__ float op_div(const float a, const float b) { + return a / b; +} + + +void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +#endif // GGML_SYCL_ELEMENTWISE_HPP diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp new file mode 100644 index 000000000..c2779df0e --- /dev/null +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -0,0 +1,55 @@ +#include +#include "outprod.hpp" + + +void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst) { + + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_TENSOR_BINARY_OP_LOCALS + + // Get SYCL queue + dpct::queue_ptr stream = ctx.stream(); + + // Dimension checks + GGML_ASSERT(ne01 == ne11); // Inner dimensions must match + GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows + GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols + + // Get data pointers + const float* src0_d = (const float*)src0->data; + const float* src1_d = (const float*)src1->data; + float* dst_d = (float*)dst->data; + + // GEMM parameters + const float alpha = 1.0f; + const float beta = 0.0f; + + // Handle transposition of src1 + const bool src1_T = ggml_is_transposed(src1); + const oneapi::mkl::transpose src1_op = + src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans; + const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); + + try { + // Perform matrix multiplication using oneMKL GEMM + oneapi::mkl::blas::gemm(*stream, + oneapi::mkl::transpose::nontrans, src1_op, + ne0, ne1, ne01, + alpha, + src0_d, ne00, + src1_d, ldb, + beta, + dst_d, ne0); + } + catch (sycl::exception const& exc) { + std::cerr << exc.what() << std::endl; + GGML_ASSERT(false); + } +} diff --git a/ggml/src/ggml-sycl/outprod.hpp b/ggml/src/ggml-sycl/outprod.hpp new file mode 100644 index 000000000..9c042738a --- /dev/null +++ b/ggml/src/ggml-sycl/outprod.hpp @@ -0,0 +1,11 @@ +#ifndef GGML_SYCL_OUTPROD_HPP +#define GGML_SYCL_OUTPROD_HPP + +#include "common.hpp" + +void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst); + + +#endif // GGML_SYCL_OUTPROD_HPP + diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp index 340ab8e93..af1890727 100644 --- a/ggml/src/ggml-sycl/presets.hpp +++ b/ggml/src/ggml-sycl/presets.hpp @@ -25,6 +25,11 @@ #define SYCL_RELU_BLOCK_SIZE 256 #define SYCL_HARDSIGMOID_BLOCK_SIZE 256 #define SYCL_HARDSWISH_BLOCK_SIZE 256 +#define SYCL_EXP_BLOCK_SIZE 256 +#define SYCL_NEG_BLOCK_SIZE 256 +#define SYCL_SIGMOID_BLOCK_SIZE 256 +#define SYCL_SQRT_BLOCK_SIZE 256 +#define SYCL_SIN_BLOCK_SIZE 256 #define SYCL_SQR_BLOCK_SIZE 256 #define SYCL_CPY_BLOCK_SIZE 32 #define SYCL_SCALE_BLOCK_SIZE 256 @@ -41,6 +46,7 @@ #define SYCL_ACC_BLOCK_SIZE 256 #define SYCL_IM2COL_BLOCK_SIZE 256 #define SYCL_POOL2D_BLOCK_SIZE 256 +#define SYCL_ARGMAX_BLOCK_SIZE 256 #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256 #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp new file mode 100644 index 000000000..4c737f4bf --- /dev/null +++ b/ggml/src/ggml-sycl/wkv6.cpp @@ -0,0 +1,138 @@ +#include +#include "wkv6.hpp" + +constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE + +// Helper function for the main kernel +static void rwkv_wkv_f32_kernel( + const int B, const int T, const int C, const int H, + const float* k, const float* v, const float* r, + const float* tf, const float* td, const float* s, + float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) { + + const int tid = item_ct1.get_local_id(2); + const int bid = item_ct1.get_group(2); + + const int head_size = WKV_BLOCK_SIZE; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + // Set up shared memory pointers + float* _k = shared_mem; + float* _r = _k + head_size; + float* _tf = _r + head_size; + float* _td = _tf + head_size; + + // Local state array + float state[WKV_BLOCK_SIZE]; + + // Load initial state + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + // Sync threads before shared memory operations + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Load time-mixing parameters + _tf[tid] = tf[head_i * head_size + tid]; + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Main sequence processing loop + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; + t += C) { + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Load current timestep data to shared memory + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + const float _v = v[t]; + float y = 0; + + // Process in chunks of 4 for better vectorization + sycl::float4 k4, r4, tf4, td4, s4, kv4; + #pragma unroll + for (int j = 0; j < head_size; j += 4) { + // Load data in vec4 chunks + k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); + + // Compute key-value product + sycl::float4 kv4 = k4 * _v; + + // Accumulate weighted sum + y += sycl::dot(r4, tf4 * kv4 + s4); + + // Update state + s4 = s4 * td4 + kv4; + + // Store updated state + state[j] = s4.x(); + state[j+1] = s4.y(); + state[j+2] = s4.z(); + state[j+3] = s4.w(); + } + + dst[t] = y; + } + + // Save final state + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } +} + +void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst) { + + const float* k_d = (const float*)dst->src[0]->data; + const float* v_d = (const float*)dst->src[1]->data; + const float* r_d = (const float*)dst->src[2]->data; + const float* tf_d = (const float*)dst->src[3]->data; + const float* td_d = (const float*)dst->src[4]->data; + const float* s_d = (const float*)dst->src[5]->data; + float* dst_d = (float*)dst->data; + + const int64_t B = dst->src[5]->ne[1]; + const int64_t T = dst->src[0]->ne[3]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[2]; + + GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64 + + dpct::queue_ptr stream = ctx.stream(); + + // Calculate execution configuration + const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td + sycl::range<3> block_dims(1, 1, C / H); + sycl::range<3> grid_dims(1, 1, B * H); + + // Submit kernel + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv_f32_kernel( + B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, + item_ct1, shared_mem_acc.get_pointer() + ); + }); + }); +} diff --git a/ggml/src/ggml-sycl/wkv6.hpp b/ggml/src/ggml-sycl/wkv6.hpp new file mode 100644 index 000000000..ddfa3377b --- /dev/null +++ b/ggml/src/ggml-sycl/wkv6.hpp @@ -0,0 +1,10 @@ +#ifndef GGML_SYCL_WKV6_HPP +#define GGML_SYCL_WKV6_HPP + +#include "common.hpp" + +void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor * dst); + + +#endif // GGML_SYCL_WKV6_HPP diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 266a0d6f0..bc034015f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -975,7 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "WIN_UNPART", "GET_REL_POS", "ADD_REL_POS", - "RWKV_WKV", + "RWKV_WKV6", "UNARY", @@ -1070,7 +1070,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "win_unpart(x)", "get_rel_pos(x)", "add_rel_pos(x)", - "rwkv_wkv(k, v, r, tf, td, s)", + "rwkv_wkv6(k, v, r, tf, td, s)", "unary(x)", @@ -4503,9 +4503,9 @@ struct ggml_tensor * ggml_add_rel_pos_inplace( return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); } -// ggml_rwkv_wkv +// ggml_rwkv_wkv6 -struct ggml_tensor * ggml_rwkv_wkv( +struct ggml_tensor * ggml_rwkv_wkv6( struct ggml_context * ctx, struct ggml_tensor * k, struct ggml_tensor * v, @@ -4537,7 +4537,7 @@ struct ggml_tensor * ggml_rwkv_wkv( const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_RWKV_WKV; + result->op = GGML_OP_RWKV_WKV6; result->src[0] = k; result->src[1] = v; result->src[2] = r; @@ -6084,7 +6084,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_GET_REL_POS: case GGML_OP_ADD_REL_POS: - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: diff --git a/src/llama.cpp b/src/llama.cpp index 6719edb38..034441e1f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7011,7 +7011,7 @@ static const std::map llm_tensor_info_mapping = { {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, - {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV}}, + {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -7127,7 +7127,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); } break; - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: { // FIXME const int64_t S = 123; @@ -7140,7 +7140,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * tf = w; ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); - op_tensor = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); } break; default: GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); @@ -10083,7 +10083,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix( v = ggml_transpose(ctx, v); r = ggml_transpose(ctx, r); - struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); + struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0); *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6cc77edab..9d48a2717 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1614,8 +1614,8 @@ struct test_ssm_scan : public test_case { } }; -// GGML_OP_RWKV_WKV -struct test_rwkv_wkv : public test_case { +// GGML_OP_RWKV_WKV6 +struct test_rwkv_wkv6 : public test_case { const ggml_type type; const int64_t head_count; @@ -1627,7 +1627,7 @@ struct test_rwkv_wkv : public test_case { return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs); } - test_rwkv_wkv(ggml_type type = GGML_TYPE_F32, + test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32, int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} @@ -1639,7 +1639,7 @@ struct test_rwkv_wkv : public test_case { ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector{ head_size, head_count }.data()); ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector{ 1, head_size, head_count, n_tokens }.data()); ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector{ head_size * head_size * head_count, n_seqs }.data()); - ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s); + ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s); return out; } }; @@ -3499,10 +3499,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); - test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1)); - test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1)); - test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4)); - test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4)); + test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1)); + test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1)); + test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4)); + test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4)); #if 1 for (ggml_type type_a : base_types) {