From 696faa866059cb6e227ac3543bb274adac88b8ab Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 18:49:39 +0200 Subject: [PATCH] kompute : fix rope_f32 and scale ops (#5008) --- ggml-kompute.cpp | 3 ++- kompute-shaders/op_rope_f32.comp | 38 +++++++++++++++++++------------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 10f94f18c..0f0003c48 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1540,7 +1540,8 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_OP_SCALE: { - const float scale = *(const float *) src1->data; + float scale; memcpy(&scale, dst->op_params, sizeof(float)); + ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale); } break; case GGML_OP_UNARY: diff --git a/kompute-shaders/op_rope_f32.comp b/kompute-shaders/op_rope_f32.comp index 104ae0ba4..2adf5eb4e 100644 --- a/kompute-shaders/op_rope_f32.comp +++ b/kompute-shaders/op_rope_f32.comp @@ -35,31 +35,39 @@ void main() { const float x0 = inA[src]; const float x1 = inA[src+1]; - out_[dst_data] = x0*cos_theta - x1*sin_theta; + out_[dst_data] = x0*cos_theta - x1*sin_theta; out_[dst_data+1] = x0*sin_theta + x1*cos_theta; } } else { const float inv_ndims = -1.f/pcs.n_dims; - for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) { - for (uint ic = 0; ic < pcs.n_dims; ic += 2) { - const uint cur_rot = ib * pcs.n_dims + ic; + for (uint ic = 0; ic < pcs.n_dims; ic += 2) { + const uint cur_rot = ic; - float cos_theta, sin_theta; - rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); + float cos_theta, sin_theta; + rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); - theta *= theta_scale; + theta *= theta_scale; - const uint i0 = ib*pcs.n_dims + ic/2; + const uint i0 = ic/2; - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - const float x0 = inA[src]; - const float x1 = inA[src+pcs.n_dims/2]; + const float x0 = inA[src]; + const float x1 = inA[src+pcs.n_dims/2]; - out_[dst_data] = x0*cos_theta - x1*sin_theta; - out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta; - } + out_[dst_data] = x0*cos_theta - x1*sin_theta; + out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta; + } + + for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) { + const uint i0 = ic; + + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ + + out_[dst_data + 0] = inA[src + 0]; + out_[dst_data + 1] = inA[src + 1]; } } }