diff --git a/ggml-metal.m b/ggml-metal.m
index c9e570dbf..c39f1c151 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -187,6 +187,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -771,6 +772,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+        case GGML_OP_SSM_CONV:
+            return false; // TODO: report as supported once kernel_ssm_conv_f32 is implemented and dispatched
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -968,6 +971,10 @@ static enum ggml_status ggml_metal_graph_compute(
             //    GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
             //            ggml_is_contiguous(src1), src1->name);
             //}
+            //if (src2) {
+            //    GGML_METAL_LOG_INFO("%s: src2 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src2->type), ne20, ne21, ne22,
+            //            ggml_is_contiguous(src2), src2->name);
+            //}
             //if (dst) {
             //    GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
             //            dst->name);
@@ -2688,6 +2695,55 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                         }
                     } break;
+                case GGML_OP_SSM_CONV:
+                    {
+                        // TODO: not implemented - nothing is encoded below, so ggml_metal_supports_op must keep rejecting this op
+                        id<MTLComputePipelineState> pipeline = nil;
+                        //pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
+
+                        //[encoder setComputePipelineState:pipeline];
+                        //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        //[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        //[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        //[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                        //[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        //[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                        //[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                        //[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                        //[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                        //[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                        //[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                        //[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                        //[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                        //[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                        //[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                        //[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                        //[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                        //[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                        //[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                        //[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                        //[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+                        //[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+                        //[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+                        //[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+                        //[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+                        //[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+                        //[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+                        //[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+                        //[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                        //[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
+
+                        //if (bcast_row) {
+                        //    const int64_t n = ggml_nelements(dst)/4;
+
+                        //    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        //} else {
+                        //    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                        //    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        //}
+                    } break;
                 case GGML_OP_DUP:
                 case GGML_OP_CPY:
                 case GGML_OP_CONT:
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 8ff70d7a7..0ce719cb2 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -2698,6 +2698,31 @@ kernel void kernel_flash_attn_ext_vec_f16(
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
+// GGML_OP_SSM_CONV (WIP): only the signature is declared - the body is still empty,
+// so this kernel writes nothing to dst and must not be dispatched yet
+kernel void kernel_ssm_conv_f32(
+        device const float * src0,
+        device const float * src1,
+        device const float * src2,
+        device const int32_t * src3,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne11,
+        constant int64_t & ne20,
+
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb21,
+        constant uint64_t & nb22,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+}
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device half * dst,
diff --git a/llama.cpp b/llama.cpp
index 678c49094..0a1385788 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8046,7 +8046,7 @@ static struct ggml_tensor * llm_build_mamba(
     // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
     ggml_build_forward_expand(graph,
         ggml_cpy(ctx,
-            ggml_view_2d(ctx, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
+            ggml_view_2d(ctx, x_conv, d_conv - 1, d_inner * n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
             ggml_view_1d(ctx, rs.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
 
     // extract x from x_conv
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index de74585da..f4c194591 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1561,6 +1561,56 @@ struct test_flash_attn_ext : public test_case {
     }
 };
 
+// GGML_OP_SSM_CONV
+struct test_ssm_conv : public test_case {
+    const ggml_type type_s;
+    const ggml_type type_x;
+    const ggml_type type_c;
+    const ggml_type type_sq;
+    const int64_t d_inner;
+    const int64_t d_conv;
+    const int64_t n_tokens;
+    const int64_t n_rs;
+
+    std::string vars() override {
+        return VARS_TO_STR8(type_s, type_x, type_c, type_sq, d_inner, d_conv, n_tokens, n_rs);
+    }
+
+    test_ssm_conv(ggml_type type_s = GGML_TYPE_F32,
+            ggml_type type_x = GGML_TYPE_F32,
+            ggml_type type_c = GGML_TYPE_F32,
+            ggml_type type_sq = GGML_TYPE_I32,
+            int64_t d_inner = 10,
+            int64_t d_conv = 10,
+            int64_t n_tokens = 10,
+            int64_t n_rs = 10)
+        : type_s(type_s), type_x(type_x), type_c(type_c), type_sq(type_sq), d_inner(d_inner), d_conv(d_conv), n_tokens(n_tokens), n_rs(n_rs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * s = ggml_new_tensor_3d (ctx, type_s, d_conv-1, d_inner, n_rs);
+        ggml_tensor * x = ggml_new_tensor_2d (ctx, type_x, d_inner, n_tokens);
+        ggml_tensor * c = ggml_new_tensor_2d (ctx, type_c, d_conv, d_inner);
+        ggml_tensor * sq = ggml_new_tensor_1d(ctx, type_sq, n_tokens);
+        ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c, sq);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // I32 tensors (sq with the default types): one random index per element, in [0, n_rs)
+                std::vector<int32_t> data(t->ne[0]);
+                for (int i = 0; i < t->ne[0]; i++) {
+                    data[i] = rand() % n_rs;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, t->ne[0] * sizeof(int32_t));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -2246,6 +2296,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    test_cases.emplace_back(new test_ssm_conv());
+
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
     test_cases.emplace_back(new test_llama(1));