diff --git a/CMakeLists.txt b/CMakeLists.txt index d26aedaf3..aa453b6b2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -490,7 +490,8 @@ if (LLAMA_KOMPUTE) kompute/op_getrows_q4_0.comp kompute/op_getrows_q4_1.comp kompute/op_getrows_q6_k.comp - kompute/op_rope.comp + kompute/op_rope_f16.comp + kompute/op_rope_f32.comp kompute/op_cpy_f16_f16.comp kompute/op_cpy_f16_f32.comp kompute/op_cpy_f32_f16.comp @@ -521,7 +522,8 @@ if (LLAMA_KOMPUTE) shaderop_getrows_q4_0.h shaderop_getrows_q4_1.h shaderop_getrows_q6_k.h - shaderop_rope.h + shaderop_rope_f16.h + shaderop_rope_f32.h shaderop_cpy_f16_f16.h shaderop_cpy_f16_f32.h shaderop_cpy_f32_f16.h diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 01d70d1a6..3e3f6cc80 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -32,7 +32,8 @@ #include "shaderop_getrows_q4_0.h" #include "shaderop_getrows_q4_1.h" #include "shaderop_getrows_q6_k.h" -#include "shaderop_rope.h" +#include "shaderop_rope_f16.h" +#include "shaderop_rope_f32.h" #include "shaderop_cpy_f16_f16.h" #include "shaderop_cpy_f16_f32.h" #include "shaderop_cpy_f32_f16.h" @@ -1175,51 +1176,66 @@ void ggml_vk_get_rows_q6_k(Args&&... args) { ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK_NL, std::forward(args)...); } -void ggml_vk_rope(kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - uint32_t n_past, int32_t n_dims, int32_t mode, - float freq_base, float freq_scale, - int32_t ne01, int32_t ne02, int32_t ne03, - uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, - int32_t ne0, - uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3) { - const static auto spirv = getSpirvShader(kp::shader_data::op_rope_comp_spv, - kp::shader_data::op_rope_comp_spv_len); +void ggml_vk_rope( + kp::Sequence& seq, + const std::shared_ptr& inA, + const std::shared_ptr& inB, + const std::shared_ptr& out, + uint32_t inAOff, uint32_t inBOff, uint32_t outOff, + ggml_type src0t, int32_t n_dims, int32_t mode, + float freq_base, float freq_scale, + int32_t ne01, int32_t ne02, int32_t ne03, + uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, + int32_t ne0, + uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3 +) { + GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32); - GGML_ASSERT(nb03%sizeof(float) == 0); - GGML_ASSERT(nb02%sizeof(float) == 0); - GGML_ASSERT(nb01%sizeof(float) == 0); - GGML_ASSERT(nb00%sizeof(float) == 0); - GGML_ASSERT(nb3%sizeof(float) == 0); - GGML_ASSERT(nb2%sizeof(float) == 0); - GGML_ASSERT(nb1%sizeof(float) == 0); - GGML_ASSERT(nb0%sizeof(float) == 0); + static const auto spirv_f16 = getSpirvShader( + kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len + ); + static const auto spirv_f32 = getSpirvShader( + kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len + ); + + int type_size = src0t == GGML_TYPE_F16 ? 2 : 4; + + GGML_ASSERT(nb03 % type_size == 0); + GGML_ASSERT(nb02 % type_size == 0); + GGML_ASSERT(nb01 % type_size == 0); + GGML_ASSERT(nb00 % type_size == 0); + GGML_ASSERT(nb3 % type_size == 0); + GGML_ASSERT(nb2 % type_size == 0); + GGML_ASSERT(nb1 % type_size == 0); + GGML_ASSERT(nb0 % type_size == 0); struct PushConstants { - uint32_t inOff, outOff; - uint32_t n_past; + uint32_t inAOff, inBOff, outOff; int32_t n_dims, mode; float freq_base, freq_scale; uint32_t nb00, nb01, nb02, nb03; int32_t ne0; uint32_t nb0, nb1, nb2, nb3; } pushConsts { - safe_divide(inOff, 4), safe_divide(outOff, 4), - n_past, n_dims, mode, + safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size), + n_dims, mode, freq_base, freq_scale, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3 }; + auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32"); std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}); - else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({in, out}); + if (!komputeManager()->hasAlgorithm(name)) { + s_algo = komputeManager()->algorithm( + name, s_kompute_context->pool.get(), {inA, inB, out}, + src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32, + {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts} + ); + } else { + s_algo = komputeManager()->getAlgorithm(name); + s_algo->setTensors({inA, inB, out}); s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); s_algo->setPushConstants({pushConsts}); s_algo->updateDescriptors(s_kompute_context->pool.get()); @@ -1506,14 +1522,16 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_OP_ROPE: { - const int n_past = ((int32_t *) dst->op_params)[0]; + GGML_ASSERT(ne10 == ne02); + GGML_ASSERT(src0t == dstt); + // const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; float freq_base; float freq_scale; memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - ggml_vk_rope(seq, id_src0, id_dst, off_src0, off_dst, n_past, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3); + ggml_vk_rope(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3); } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/kompute/op_rope_f16.comp b/kompute/op_rope_f16.comp new file mode 100644 index 000000000..fd3943c81 --- /dev/null +++ b/kompute/op_rope_f16.comp @@ -0,0 +1,89 @@ +/** + * Copyright (c) 2023 Nomic, Inc. All rights reserved. + * + * This software is licensed under the terms of the Software for Open Models License (SOM), + * version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany + * this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc. + */ + +#version 450 + +#include "common.comp" + +// TODO: use a local size of 32 or more (Metal uses 1024) +layout(local_size_x = 1) in; + +layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; +layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int n_dims; + int mode; + float freq_base; + float freq_scale; + uint nb00; + uint nb01; + uint nb02; + uint nb03; + int ne0; + uint nb0; + uint nb1; + uint nb2; + uint nb3; +} pcs; + +void main() { + const uint i3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.y; + const uint i1 = gl_WorkGroupID.x; + + const bool is_neox = (pcs.mode & 2) != 0; + const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); + + const int p = inB[pcs.inBOff + i2]; + + float theta = pcs.freq_scale * float(p); + + if (!is_neox) { + for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) { + const float cos_theta = cos(theta); + const float sin_theta = sin(theta); + + theta *= theta_scale; + + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ + + const float x0 = float(inA[src]); + const float x1 = float(inA[src+1]); + + out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta); + out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta); + } + } else { + const float inv_ndims = -1.f/pcs.n_dims; + for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) { + for (uint ic = 0; ic < pcs.n_dims; ic += 2) { + const float cos_theta = cos(theta); + const float sin_theta = sin(theta); + + theta *= theta_scale; + + const uint i0 = ib*pcs.n_dims + ic/2; + + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ + + const float x0 = float(inA[src]); + const float x1 = float(inA[src+pcs.n_dims/2]); + + out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta); + out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta); + } + } + } +} diff --git a/kompute/op_rope.comp b/kompute/op_rope_f32.comp similarity index 78% rename from kompute/op_rope.comp rename to kompute/op_rope_f32.comp index 8c2854636..6024c3e5e 100644 --- a/kompute/op_rope.comp +++ b/kompute/op_rope_f32.comp @@ -12,13 +12,14 @@ layout(local_size_x = 1) in; -layout (binding = 0) readonly buffer tensorIn { float in_[]; }; -layout (binding = 1) writeonly buffer tensorOut { float out_[]; }; +layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; +layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; layout (push_constant) uniform parameter { - uint inOff; + uint inAOff; + uint inBOff; uint outOff; - uint n_past; int n_dims; int mode; float freq_base; @@ -42,7 +43,7 @@ void main() { const bool is_neox = (pcs.mode & 2) != 0; const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); - const uint p = ((pcs.mode & 1) == 0 ? pcs.n_past + i2 : i2); + const int p = inB[pcs.inBOff + i2]; float theta = pcs.freq_scale * float(p); @@ -53,11 +54,11 @@ void main() { theta *= theta_scale; - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inOff; // Based from in + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - const float x0 = in_[src]; - const float x1 = in_[src+1]; + const float x0 = inA[src]; + const float x1 = inA[src+1]; out_[dst_data] = x0*cos_theta - x1*sin_theta; out_[dst_data+1] = x0*sin_theta + x1*cos_theta; @@ -73,11 +74,11 @@ void main() { const uint i0 = ib*pcs.n_dims + ic/2; - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inOff; // Based from in + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - const float x0 = in_[src]; - const float x1 = in_[src+pcs.n_dims/2]; + const float x0 = inA[src]; + const float x1 = inA[src+pcs.n_dims/2]; out_[dst_data] = x0*cos_theta - x1*sin_theta; out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta; diff --git a/llama.cpp b/llama.cpp index a56ffce9f..8455424b4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2772,8 +2772,9 @@ static struct ggml_cgraph * llm_build_llama( } // shift the entire K-cache if needed + struct ggml_tensor * K_shift = nullptr; if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); offload_func_kq(K_shift); ggml_set_name(K_shift, "K_shift"); ggml_allocr_alloc(lctx.alloc, K_shift); @@ -3024,6 +3025,11 @@ static struct ggml_cgraph * llm_build_llama( ggml_vk_h2d_all(lctx.ctx_kompute); } else { ggml_vk_h2d_tensor(lctx.ctx_kompute, toDeviceTensor); + ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_pos); + ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_mask); + if (K_shift) { + ggml_vk_h2d_tensor(lctx.ctx_kompute, K_shift); + } } } #endif @@ -3589,8 +3595,9 @@ static struct ggml_cgraph * llm_build_falcon( } // shift the entire K-cache if needed + struct ggml_tensor * K_shift = nullptr; if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); offload_func_kq(K_shift); ggml_set_name(K_shift, "K_shift"); ggml_allocr_alloc(lctx.alloc, K_shift); @@ -3820,6 +3827,11 @@ static struct ggml_cgraph * llm_build_falcon( ggml_vk_h2d_all(lctx.ctx_kompute); } else { ggml_vk_h2d_tensor(lctx.ctx_kompute, toDeviceTensor); + ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_pos); + ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_mask); + if (K_shift) { + ggml_vk_h2d_tensor(lctx.ctx_kompute, K_shift); + } } } #endif