diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 4747850cf..239f913f5 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -1356,7 +1356,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph { if (ggml_nelements(src1) == ne10) { // src1 is a row - ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00); + ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00); } else { ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4); } @@ -1365,7 +1365,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph { if (ggml_nelements(src1) == ne10) { // src1 is a row - ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00); + ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00); } else { ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4); } @@ -1373,7 +1373,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph case GGML_OP_SCALE: { const float scale = *(const float *) src1->data; - ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale); + ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8, scale); } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { @@ -1387,7 +1387,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_UNARY_OP_GELU: { - ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4); + ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8); } break; default: { diff --git a/kompute/op_addrow.comp b/kompute/op_addrow.comp index 926c929e4..bf674f829 100644 --- a/kompute/op_addrow.comp +++ b/kompute/op_addrow.comp @@ -24,7 +24,10 @@ layout(push_constant) uniform PushConstants { } pcs; void main() { - const uint i = gl_WorkGroupID.x; + const uint baseIndex = gl_WorkGroupID.x * 4; - out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff]; -} \ No newline at end of file + for (uint x = 0; x < 4; x++) { + const uint i = baseIndex + x; + out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff]; + } +} diff --git a/kompute/op_gelu.comp b/kompute/op_gelu.comp index f74a14f7e..1412ee1ab 100644 --- a/kompute/op_gelu.comp +++ b/kompute/op_gelu.comp @@ -20,9 +20,9 @@ layout(push_constant) uniform PushConstants { } pcs; void main() { - const uint baseIndex = gl_WorkGroupID.x * 4; + const uint baseIndex = gl_WorkGroupID.x * 8; - for (uint x = 0; x < 4; x++) { + for (uint x = 0; x < 8; x++) { const uint i = baseIndex + x; const float y = in_[i + pcs.inOff]; out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y))); diff --git a/kompute/op_mulrow.comp b/kompute/op_mulrow.comp index 498dbdfcd..955fe26bf 100644 --- a/kompute/op_mulrow.comp +++ b/kompute/op_mulrow.comp @@ -24,7 +24,10 @@ layout(push_constant) uniform PushConstants { } pcs; void main() { - const uint i = gl_WorkGroupID.x; + const uint baseIndex = gl_WorkGroupID.x * 4; - out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i % pcs.row) + pcs.inBOff]; + for (uint x = 0; x < 4; x++) { + const uint i = baseIndex + x; + out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i % pcs.row) + pcs.inBOff]; + } } \ No newline at end of file diff --git a/kompute/op_scale.comp b/kompute/op_scale.comp index 8530aaf3e..2ec524435 100644 --- a/kompute/op_scale.comp +++ b/kompute/op_scale.comp @@ -22,7 +22,10 @@ layout(push_constant) uniform PushConstants { } pcs; void main() { - const uint i = gl_WorkGroupID.x; + const uint baseIndex = gl_WorkGroupID.x * 8; - out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale; + for (uint x = 0; x < 8; x++) { + const uint i = baseIndex + x; + out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale; + } } \ No newline at end of file diff --git a/kompute/op_silu.comp b/kompute/op_silu.comp index 8c7bfe321..9233fd5a1 100644 --- a/kompute/op_silu.comp +++ b/kompute/op_silu.comp @@ -18,8 +18,8 @@ layout(push_constant) uniform PushConstants { uint inOff; uint outOff; } pcs; -void main() { +void main() { const uint baseIndex = gl_WorkGroupID.x * 4; for (uint x = 0; x < 4; x++) { diff --git a/kompute/src/Algorithm.cpp b/kompute/src/Algorithm.cpp index ea81fd97b..f8f1c7e36 100644 --- a/kompute/src/Algorithm.cpp +++ b/kompute/src/Algorithm.cpp @@ -387,6 +387,10 @@ Algorithm::recordDispatch(const vk::CommandBuffer& commandBuffer) void Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize) { + if (workgroup[0] > 65535) { + fprintf(stderr, "workgroup size is %d\n", workgroup[0]); + fflush(stderr); + } KP_LOG_INFO("Kompute OpAlgoCreate setting dispatch size");