diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 496ec61c3..f329bc272 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -8830,12 +8830,11 @@ static void rope( dst[i + 1] = x0*sin_theta + x1*cos_theta; } -template +template static void rope_neox( const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims -, - const sycl::nd_item<3> &item_ct1) { + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, + const float * freq_factors, const sycl::nd_item<3> &item_ct1) { const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); @@ -8863,8 +8862,10 @@ static void rope_neox( float cur_rot = inv_ndims * ic - ib; const int p = has_pos ? pos[i2] : 0; + const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f; + const float theta_base = - p * freq_scale * dpct::pow(theta_scale, col / 2.0f); + p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor; float cos_theta, sin_theta; rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); @@ -12413,7 +12414,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows, const int32_t *pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, - dpct::queue_ptr stream) { + const float * freq_factors, dpct::queue_ptr stream) { GGML_ASSERT(ncols % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); @@ -12423,38 +12424,48 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows, const float inv_ndims = -1.0f / n_dims; if (pos == nullptr) { - /* - DPCT1049:42: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ncols, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, inv_ndims, - item_ct1); - }); + if (freq_factors == nullptr) { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ncols, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, inv_ndims, freq_factors, + item_ct1); + }); + } else { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ncols, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, inv_ndims, freq_factors, + item_ct1); + }); + } } else { - /* - DPCT1049:43: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ncols, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, inv_ndims, item_ct1); - }); + if (freq_factors == nullptr) { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ncols, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1); + }); + } else { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ncols, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1); + }); + } } } @@ -13986,9 +13997,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const dpct::queue_ptr &main_stream) { -#pragma message("TODO: implement phi3 frequency factors support") -#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") - GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); + const ggml_tensor * src2 = dst->src[2]; GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -14014,6 +14023,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const float * freq_factors = nullptr; const int32_t * pos = nullptr; if ((mode & 1) == 0) { GGML_ASSERT(src1->type == GGML_TYPE_I32); @@ -14024,6 +14034,16 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, const bool is_neox = mode & 2; const bool is_glm = mode & 4; + if (is_neox) { + pos = (const int32_t *) src1_dd; + + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; + } + } else { + GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox"); + } + rope_corr_dims corr_dims; ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); @@ -14035,13 +14055,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, if (src0->type == GGML_TYPE_F32) { rope_neox_sycl( (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, main_stream + attn_factor, corr_dims, freq_factors, main_stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, - main_stream); + freq_factors, main_stream); } else { GGML_ASSERT(false); }