diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index fb2af7b1f..4a8785b79 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -1,13 +1,5 @@ #include "softmax.hpp" -template static inline float t2f32(T val) { - return static_cast(val); -} - -template <> inline float t2f32(sycl::half val) { - return static_cast(val); -} - template static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, @@ -51,7 +43,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); + const float val = x[ix]*scale + (mask ? slope*static_cast(mask[iy]) : 0.0f); vals[col] = val; max_val = sycl::max(max_val, val); @@ -142,7 +134,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { - soft_max_f32(x, mask, dst, ncols_par, + soft_max_f32(x, mask, dst, ncols_par, nrows_y, scale, max_bias, m0, m1, n_head_log2, item_ct1, get_pointer(local_buf_acc)); @@ -174,60 +166,60 @@ static void soft_max_f32_sycl(const float * x, const T * mask, const size_t local_mem_size = stream->get_device().get_info(); if (n_local_scratch*sizeof(float) < local_mem_size) { if (ncols_x > max_block_size) { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); return; } switch (ncols_x) { case 32: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 64: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 128: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 256: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 512: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 1024: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 2048: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 4096: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; default: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; } } else { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, WARP_SIZE, stream); }