softmax: review update

This commit is contained in:
Akarshan Biswas 2025-01-20 18:02:36 +05:30
parent 82d5c0dd80
commit b913e83c81
No known key found for this signature in database
GPG Key ID: 52A578A14B32134D

View File

@ -1,13 +1,5 @@
#include "softmax.hpp"
// Converts a single value of type T to float.
//
// Used by the softmax kernel to read a mask element before scaling it.
// One generic definition is enough: static_cast<float> performs the
// conversion for every T this helper is called with. This also covers
// sycl::half, so a separate explicit specialization with the very same body
// would be redundant duplication.
//
// @param val  value to convert
// @return     val converted to float via static_cast
template <typename T> static inline float t2f32(T val) {
    return static_cast<float>(val);
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
const int nrows_y, const float scale, const float max_bias, const float m0,
@ -51,7 +43,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
vals[col] = val;
max_val = sycl::max(max_val, val);
@ -142,7 +134,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<vals_smem, ncols_template, block_size_template, T>(x, mask, dst, ncols_par,
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
nrows_y, scale, max_bias, m0,
m1, n_head_log2, item_ct1,
get_pointer(local_buf_acc));
@ -174,60 +166,60 @@ static void soft_max_f32_sycl(const float * x, const T * mask,
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
if (n_local_scratch*sizeof(float) < local_mem_size) {
if (ncols_x > max_block_size) {
soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
return;
}
switch (ncols_x) {
case 32:
soft_max_f32_submitter<true, 32, 32, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 64:
soft_max_f32_submitter<true, 64, 64, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 128:
soft_max_f32_submitter<true, 128, 128, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 256:
soft_max_f32_submitter<true, 256, 256, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 512:
soft_max_f32_submitter<true, 512, 512, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 1024:
soft_max_f32_submitter<true, 1024, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 2048:
soft_max_f32_submitter<true, 2048, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 4096:
soft_max_f32_submitter<true, 4096, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
default:
soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
}
} else {
soft_max_f32_submitter<false, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, WARP_SIZE, stream);
}