From 8fb57ac0fbf21d09abd21f3c167ee2cec8bb7094 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 2 Jun 2024 22:49:24 -0400 Subject: [PATCH] llama : use im2col and mul_mat to perform convolution for Mamba This removes the need for ggml_ssm_conv!!! But performance seems slightly worse on my system, especially for prompt processing. Maybe ggml_mul_mat isn't optimized for small row sizes? More performance testing is necessary until GGML_OP_SSM_CONV is removed. * ggml : make ggml_ssm_scan not modify its source tensors * llama : fix shared recurrent tail cell count for small ubatch sizes Otherwise it was impossible to run the 'parallel' example with '-ub 1' with a Mamba or Jamba model. --- ggml.c | 121 +++++++++++++++++++++--------------------------------- ggml.h | 3 +- llama.cpp | 83 +++++++++++++++++++++++++------------ 3 files changed, 103 insertions(+), 104 deletions(-) diff --git a/ggml.c b/ggml.c index 426501015..253b3fa41 100644 --- a/ggml.c +++ b/ggml.c @@ -7124,26 +7124,24 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, + struct ggml_tensor * sx, struct ggml_tensor * c) { - GGML_ASSERT(ggml_is_3d(s)); - GGML_ASSERT(ggml_is_3d(x)); + GGML_ASSERT(ggml_is_3d(sx)); GGML_ASSERT(ggml_is_matrix(c)); const int64_t d_conv = c->ne[0]; const int64_t d_inner = c->ne[1]; - const int64_t n_t = x->ne[1]; // tokens per sequence - const int64_t n_s = s->ne[2]; + const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence + const int64_t n_s = sx->ne[2]; - GGML_ASSERT(s->ne[0] == d_conv - 1); - GGML_ASSERT(s->ne[1] == d_inner); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(x->ne[2] == n_s); + // TODO: maybe support other strides than 1? 
+ GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); + GGML_ASSERT(sx->ne[1] == d_inner); + GGML_ASSERT(n_t >= 0); bool is_node = false; - if (s->grad || x->grad || c->grad) { + if (sx->grad || c->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } @@ -7152,9 +7150,8 @@ struct ggml_tensor * ggml_ssm_conv( result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = s; - result->src[1] = x; - result->src[2] = c; + result->src[0] = sx; + result->src[1] = c; return result; } @@ -7203,8 +7200,8 @@ struct ggml_tensor * ggml_ssm_scan( is_node = true; } - // y - struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]); + // concatenated y + ssm_states + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -16252,22 +16249,21 @@ static void ggml_compute_forward_ssm_conv_f32( return; } - const struct ggml_tensor * src0 = dst->src[0]; // conv_state - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight + const struct ggml_tensor * src0 = dst->src[0]; // conv_x + const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight const int ith = params->ith; const int nth = params->nth; - const int nc = src2->ne[0]; // d_conv + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t const int nr = src0->ne[1]; // d_inner - const int n_t = src1->ne[1]; // tokens per sequence - const int n_s = src0->ne[2]; // number of sequences in the batch + const int n_t = dst->ne[1]; // tokens per sequence + const int n_s = dst->ne[2]; // number of sequences in the batch - GGML_ASSERT(ggml_are_same_shape(src1, dst)); + GGML_ASSERT( dst->ne[0] == nr); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); - 
GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // rows per thread @@ -16278,54 +16274,28 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)? - // This would avoid having to copy into an intermediate buffer, but the state would be bigger. - float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith; - for (int i3 = 0; i3 < n_s; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - - // copy the state into working memory - // can't use memcpy because (d_conv) != (d_conv - 1) - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; - } - } - for (int i2 = 0; i2 < n_t; ++i2) { - float * x = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - - // shift state left - memmove(s, s + 1, (nc*ir - 1) * sizeof(float)); + // {d_conv - 1 + n_t, d_inner, n_seqs} + // sliding window + const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} + float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + // TODO: transpose the output for smaller strides for big batches? 
// d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // insert x on the last column - s[(nc - 1) + i1*nc] = x0[i1]; - } - - // it seems a little faster when this is separate from the state shift for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision float sumf = 0.0f; + + // d_conv for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - sumf += s[i] * c[i]; + sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; } x[i1] = sumf; } } - - // copy the state out of it - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc]; - } - } } } @@ -16368,7 +16338,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t n_t = src1->ne[1]; // number of tokens per sequence const int64_t n_s = src0->ne[2]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -16377,6 +16347,10 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src5->nb[0] == sizeof(float)); // required for the dot product between s and C GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // required for per-sequence offsets for states + GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. 
src1->nb[3]) + GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16388,13 +16362,17 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i3 = 0; i3 < n_s; ++i3) { for (int i2 = 0; i2 < n_t; ++i2) { - float * y = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} - float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, 
n_s} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } // d_inner for (int i1 = 0; i1 < ir; ++i1) { @@ -16406,7 +16384,7 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i0 = 0; i0 < nc; ++i0) { int i = i0 + i1*nc; // state = prev_state * dA + dB * x - float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[i0]; s[i] = state; @@ -19577,13 +19555,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; - case GGML_OP_SSM_CONV: - { - const int64_t d_conv = node->src[2]->ne[0]; - const int64_t d_inner = node->src[0]->ne[1]; - - cur += sizeof(float)*d_conv*(d_inner + n_tasks - 1); - } break; case GGML_OP_CROSS_ENTROPY_LOSS: { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); diff --git a/ggml.h b/ggml.h index 9df601e2c..c772febf0 100644 --- a/ggml.h +++ b/ggml.h @@ -1803,8 +1803,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, + struct ggml_tensor * sx, struct ggml_tensor * c); GGML_API struct ggml_tensor * ggml_ssm_scan( diff --git a/llama.cpp b/llama.cpp index ce96d7b55..ecdcf3a4e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2827,11 +2827,13 @@ struct llama_rs_cache { n_shared_tail_cells += 1; n_seqs -= 1; } - } else if (rs_cell.is_empty()) { - // from shared to unique - n_seqs += 1; - if (prev_cell.tail_rc == 1) { - // it was the last tail of the previous cell + } else { + if (rs_cell.is_empty()) { + // from shared to unique + n_seqs += 1; + } + if (prev_cell.tail_rc == 1 && rs_cell.seq_nodes.size() == rs_cell.tail_rc) { + // from last 
shared to fully tail n_shared_tail_cells -= 1; } } @@ -8683,6 +8685,18 @@ static struct ggml_tensor * llm_build_mamba( conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); + // copy states which won't be changed further (between n_seqs and n_rs) + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, conv_states, (d_conv - 1)*d_inner*(n_rs - n_seqs), n_seqs*(conv_states->nb[2])), + ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*(d_conv - 1)*d_inner*ggml_element_size(conv_states_all)))); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, ssm_states, d_state*d_inner*(n_rs - n_seqs), n_seqs*(ssm_states->nb[2])), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + // the part of the states that will be used and modified struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0); struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0); @@ -8698,28 +8712,43 @@ static struct ggml_tensor * llm_build_mamba( // conv { - // Custom operator, which is needed because self-overlapping views aren't yet well supported by ggml. - // And also because this uses much less memory for large batches (4 times less when d_conv is 4). 
- // The equivalent is to concatenate the columns of conv_states and x, - // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weigth, + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_cont(ctx, ggml_transpose(ctx, x)), 0); + + // copy last (d_conv - 1) columns back into the state cache + struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, last_conv, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, // then sum the elements of each row, // (the last two steps are a dot product over rows (also doable with mul_mat)) // then permute away the ne[0] dimension, // and then you're left with the resulting x tensor. - // The new conv_states is the last (d_conv - 1) columns - // of the last 3rd dimensional "layer" of the self-overlapping view. // For simultaneous sequences, all sequences need to have the same length. - x = ggml_ssm_conv(ctx, conv, x, model.layers[il].ssm_conv1d); - // ensure conv is updated before copying into the recurrent state cache - ggml_build_forward_expand(graph, x); + // For some reason, im2col expects a F16 kernel, but doesn't even read from it. + // TODO: make im2col accept F32 kernels to directly pass ssm_conv1d to it. 
+ // => { d_conv * d_inner, n_seq_tokens, n_seqs} + x = ggml_im2col(ctx, + ggml_new_tensor_2d(ctx, GGML_TYPE_F16, d_conv, d_inner), + conv_x, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F32); - ggml_build_forward_expand(graph, - ggml_cpy(ctx, conv_states, - ggml_view_1d(ctx, conv_states_all, - (d_conv - 1)*(d_inner)*(n_rs), - rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + x = ggml_reshape_4d(ctx, x, d_conv, 1, d_inner, n_seq_tokens * n_seqs); + + // => {1, 1, d_inner, n_seq_tokens * n_seqs} + x = ggml_mul_mat(ctx, ggml_reshape_3d(ctx, model.layers[il].ssm_conv1d, d_conv, 1, d_inner), x); + x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs); + + // Alternatively, this does the same as the above + // x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); // bias x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); @@ -8746,16 +8775,16 @@ static struct ggml_tensor * llm_build_mamba( // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. 
- // => {d_inner, n_seq_tokens, n_seqs} - struct ggml_tensor * y = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); - - // The ssm scan also changes the state, ensure it's done before copying to the recurrent state cache - ggml_build_forward_expand(graph, y); + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); // store last states ggml_build_forward_expand(graph, - ggml_cpy(ctx, ssm_states, - ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); // TODO: skip computing output earlier for unused tokens