diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index 022e7b763..e663c9c11 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -79,7 +79,7 @@ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *sr float * src0_ddf = (float *) src0->data; float * src1_ddf = use_src1 ? (float *) src1->data : nullptr; float * dst_ddf = (float *) dst->data; - + /* These are never used */ ggml_sycl_pool_alloc src0_f(ctx.pool()); ggml_sycl_pool_alloc src1_f(ctx.pool()); ggml_sycl_pool_alloc dst_f(ctx.pool()); diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 0d097357c..6a1c1ea5c 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -2,6 +2,7 @@ #include "dmmv.hpp" #include "dequantize.hpp" #include "presets.hpp" +#include "ggml-impl.h" static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ @@ -973,6 +974,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec( } #else const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion + GGML_UNUSED(ctx); #endif // GGML_SYCL_F16 switch (src0->type) { @@ -1010,7 +1012,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec( convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; default: - printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); + GGML_LOG_ERROR("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); GGML_ABORT("fatal error"); break; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ed4d8bb8b..2984ed82e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf); } -static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max); -} - static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope); @@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens ggml_sycl_diag_mask_inf(ctx, dst); break; case GGML_OP_SOFT_MAX: - ggml_sycl_soft_max(ctx, dst); + ggml_sycl_op_soft_max(ctx, dst); break; case GGML_OP_ROPE: ggml_sycl_rope(ctx, dst); diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index a9b3fce0d..fb2af7b1f 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -1,7 +1,15 @@ -#include "norm.hpp" +#include "softmax.hpp" -template -static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par, +template static inline float t2f32(T val) { + return static_cast(val); +} + +template <> inline float t2f32(sycl::half val) { + return static_cast(val); +} + +template +static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; @@ -29,9 +37,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const slope = sycl::pow(base, float(exp)); } - float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols; + float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols; float max_val = -INFINITY; +#pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; @@ -42,7 +51,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f); + const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); vals[col] = val; max_val = sycl::max(max_val, val); @@ -65,7 +74,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const item_ct1.barrier(sycl::access::fence_space::local_space); max_val = buf[lane_id]; for (size_t i = 1; i < nreduce; i += 1) { - max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]); + max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]); } max_val = warp_reduce_max(max_val, item_ct1); } @@ -122,8 +131,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const } } -template -static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par, +template +static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const size_t n_local_scratch, queue_ptr stream) { @@ -133,7 +142,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float * cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { - soft_max_f32(x, mask, dst, ncols_par, + soft_max_f32(x, mask, dst, ncols_par, nrows_y, scale, max_bias, m0, m1, n_head_log2, item_ct1, get_pointer(local_buf_acc)); @@ -141,7 +150,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float * }); } -static void soft_max_f32_sycl(const float * x, const float * mask, +template +static void soft_max_f32_sycl(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, queue_ptr stream, int device) { @@ -164,81 +174,75 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const size_t local_mem_size = stream->get_device().get_info(); if (n_local_scratch*sizeof(float) < local_mem_size) { if (ncols_x > max_block_size) { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); return; } switch (ncols_x) { case 32: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 64: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 128: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 256: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 512: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 1024: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 2048: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 4096: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; default: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; } } else { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, WARP_SIZE, stream); } } -void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); -#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional - const int64_t ne00 = src0->ne[0]; - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t nrows_x = ggml_nrows(dst->src[0]); + const int64_t nrows_y = dst->src[0]->ne[1]; float scale = 1.0f; float max_bias = 0.0f; @@ -246,6 +250,23 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, - nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + ggml_sycl_set_device(ctx.device); + dpct::queue_ptr main_stream = ctx.stream(); + + if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { + //printf("%s: fp16 mask\n", __func__); + const sycl::half * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, + main_stream, ctx.device); + } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { + //printf("%s: fp32 mask\n", __func__); + const float * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } else { + /* mask unavailable */ + soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } } diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp index bdb8f712e..2cf8582ec 100644 --- a/ggml/src/ggml-sycl/softmax.hpp +++ b/ggml/src/ggml-sycl/softmax.hpp @@ -15,10 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream); +void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst); #endif // GGML_SYCL_SOFTMAX_HPP diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp index b54c20964..756437bcf 100644 --- a/ggml/src/ggml-sycl/wkv6.cpp +++ b/ggml/src/ggml-sycl/wkv6.cpp @@ -97,9 +97,6 @@ static void rwkv_wkv_f32_kernel( void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; - const float* k_d = (const float*)dst->src[0]->data; const float* v_d = (const float*)dst->src[1]->data; const float* r_d = (const float*)dst->src[2]->data; @@ -137,7 +134,4 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { ); }); }); - - GGML_UNUSED(src0); - GGML_UNUSED(src1); }