llama : use im2col and mul_mat to perform convolution for Mamba

This removes the need for ggml_ssm_conv!!!
But performance seems slighly worse on my system,
especially for prompt processing.
Maybe ggml_mul_mat isn't optimized for small row sizes?
More performance testing is necessary until GGML_OP_SSM_CONV is removed.

* ggml : make ggml_ssm_scan not modify its source tensors

* llama : fix shared recurrent tail cell count for small ubatch sizes

Otherwise it was impossible to run the 'parallel' example with '-ub 1'
with a Mamba or Jamba model.
This commit is contained in:
Francis Couture-Harpin 2024-06-02 22:49:24 -04:00
parent eb589d5e36
commit 8fb57ac0fb
3 changed files with 103 additions and 104 deletions

121
ggml.c
View File

@ -7124,26 +7124,24 @@ struct ggml_tensor * ggml_flash_attn_back(
struct ggml_tensor * ggml_ssm_conv( struct ggml_tensor * ggml_ssm_conv(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * s, struct ggml_tensor * sx,
struct ggml_tensor * x,
struct ggml_tensor * c) { struct ggml_tensor * c) {
GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(ggml_is_3d(sx));
GGML_ASSERT(ggml_is_3d(x));
GGML_ASSERT(ggml_is_matrix(c)); GGML_ASSERT(ggml_is_matrix(c));
const int64_t d_conv = c->ne[0]; const int64_t d_conv = c->ne[0];
const int64_t d_inner = c->ne[1]; const int64_t d_inner = c->ne[1];
const int64_t n_t = x->ne[1]; // tokens per sequence const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence
const int64_t n_s = s->ne[2]; const int64_t n_s = sx->ne[2];
GGML_ASSERT(s->ne[0] == d_conv - 1); // TODO: maybe support other strides than 1?
GGML_ASSERT(s->ne[1] == d_inner); GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(sx->ne[1] == d_inner);
GGML_ASSERT(x->ne[2] == n_s); GGML_ASSERT(n_t >= 0);
bool is_node = false; bool is_node = false;
if (s->grad || x->grad || c->grad) { if (sx->grad || c->grad) {
GGML_ASSERT(false); // TODO: implement GGML_ASSERT(false); // TODO: implement
is_node = true; is_node = true;
} }
@ -7152,9 +7150,8 @@ struct ggml_tensor * ggml_ssm_conv(
result->op = GGML_OP_SSM_CONV; result->op = GGML_OP_SSM_CONV;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = s; result->src[0] = sx;
result->src[1] = x; result->src[1] = c;
result->src[2] = c;
return result; return result;
} }
@ -7203,8 +7200,8 @@ struct ggml_tensor * ggml_ssm_scan(
is_node = true; is_node = true;
} }
// y // concatenated y + ssm_states
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]); struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
result->op = GGML_OP_SSM_SCAN; result->op = GGML_OP_SSM_SCAN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -16252,22 +16249,21 @@ static void ggml_compute_forward_ssm_conv_f32(
return; return;
} }
const struct ggml_tensor * src0 = dst->src[0]; // conv_state const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
const int nc = src2->ne[0]; // d_conv const int nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const int nr = src0->ne[1]; // d_inner const int nr = src0->ne[1]; // d_inner
const int n_t = src1->ne[1]; // tokens per sequence const int n_t = dst->ne[1]; // tokens per sequence
const int n_s = src0->ne[2]; // number of sequences in the batch const int n_s = dst->ne[2]; // number of sequences in the batch
GGML_ASSERT(ggml_are_same_shape(src1, dst)); GGML_ASSERT( dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// rows per thread // rows per thread
@ -16278,54 +16274,28 @@ static void ggml_compute_forward_ssm_conv_f32(
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0; const int ir = ir1 - ir0;
// TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
// This would avoid having to copy into an intermediate buffer, but the state would be bigger.
float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
for (int i3 = 0; i3 < n_s; ++i3) { for (int i3 = 0; i3 < n_s; ++i3) {
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
// copy the state into working memory
// can't use memcpy because (d_conv) != (d_conv - 1)
for (int i1 = 0; i1 < ir; ++i1) {
for (int i0 = 0; i0 < nc - 1; ++i0) {
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
}
}
for (int i2 = 0; i2 < n_t; ++i2) { for (int i2 = 0; i2 < n_t; ++i2) {
float * x = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} // {d_conv - 1 + n_t, d_inner, n_seqs}
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} // sliding window
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
// shift state left float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
// TODO: transpose the output for smaller strides for big batches?
// d_inner // d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// insert x on the last column
s[(nc - 1) + i1*nc] = x0[i1];
}
// it seems a little faster when this is separate from the state shift
for (int i1 = 0; i1 < ir; ++i1) { for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product // rowwise dot product
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
float sumf = 0.0f; float sumf = 0.0f;
// d_conv
for (int i0 = 0; i0 < nc; ++i0) { for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc; sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
sumf += s[i] * c[i];
} }
x[i1] = sumf; x[i1] = sumf;
} }
} }
// copy the state out of it
for (int i1 = 0; i1 < ir; ++i1) {
for (int i0 = 0; i0 < nc - 1; ++i0) {
s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
}
}
} }
} }
@ -16368,7 +16338,7 @@ static void ggml_compute_forward_ssm_scan_f32(
const int64_t n_t = src1->ne[1]; // number of tokens per sequence const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch const int64_t n_s = src0->ne[2]; // number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst)); GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float));
@ -16377,6 +16347,10 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C // required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
// rows per thread // rows per thread
const int dr = (nr + nth - 1)/nth; const int dr = (nr + nth - 1)/nth;
@ -16388,13 +16362,17 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int i3 = 0; i3 < n_s; ++i3) { for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) { for (int i2 = 0; i2 < n_t; ++i2) {
float * y = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }
// d_inner // d_inner
for (int i1 = 0; i1 < ir; ++i1) { for (int i1 = 0; i1 < ir; ++i1) {
@ -16406,7 +16384,7 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int i0 = 0; i0 < nc; ++i0) { for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc; int i = i0 + i1*nc;
// state = prev_state * dA + dB * x // state = prev_state * dA + dB * x
float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C) // y = rowwise_dotprod(state, C)
sumf += state * C[i0]; sumf += state * C[i0];
s[i] = state; s[i] = state;
@ -19577,13 +19555,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
} }
} break; } break;
case GGML_OP_SSM_CONV:
{
const int64_t d_conv = node->src[2]->ne[0];
const int64_t d_inner = node->src[0]->ne[1];
cur += sizeof(float)*d_conv*(d_inner + n_tasks - 1);
} break;
case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS:
{ {
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);

3
ggml.h
View File

@ -1803,8 +1803,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_ssm_conv( GGML_API struct ggml_tensor * ggml_ssm_conv(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * s, struct ggml_tensor * sx,
struct ggml_tensor * x,
struct ggml_tensor * c); struct ggml_tensor * c);
GGML_API struct ggml_tensor * ggml_ssm_scan( GGML_API struct ggml_tensor * ggml_ssm_scan(

View File

@ -2827,11 +2827,13 @@ struct llama_rs_cache {
n_shared_tail_cells += 1; n_shared_tail_cells += 1;
n_seqs -= 1; n_seqs -= 1;
} }
} else if (rs_cell.is_empty()) { } else {
// from shared to unique if (rs_cell.is_empty()) {
n_seqs += 1; // from shared to unique
if (prev_cell.tail_rc == 1) { n_seqs += 1;
// it was the last tail of the previous cell }
if (prev_cell.tail_rc == 1 && rs_cell.seq_nodes.size() == rs_cell.tail_rc) {
// from last shared to fully tail
n_shared_tail_cells -= 1; n_shared_tail_cells -= 1;
} }
} }
@ -8683,6 +8685,18 @@ static struct ggml_tensor * llm_build_mamba(
conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs);
ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs);
// copy states which won't be changed further (between n_seqs and n_rs)
ggml_build_forward_expand(graph,
ggml_cpy(ctx,
ggml_view_1d(ctx, conv_states, (d_conv - 1)*d_inner*(n_rs - n_seqs), n_seqs*(conv_states->nb[2])),
ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*(d_conv - 1)*d_inner*ggml_element_size(conv_states_all))));
ggml_build_forward_expand(graph,
ggml_cpy(ctx,
ggml_view_1d(ctx, ssm_states, d_state*d_inner*(n_rs - n_seqs), n_seqs*(ssm_states->nb[2])),
ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*d_state*d_inner*ggml_element_size(ssm_states_all))));
// the part of the states that will be used and modified
struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0); struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0);
struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0); struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0);
@ -8698,28 +8712,43 @@ static struct ggml_tensor * llm_build_mamba(
// conv // conv
{ {
// Custom operator, which is needed because self-overlapping views aren't yet well supported by ggml. // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
// And also because this uses much less memory for large batches (4 times less when d_conv is 4). struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_cont(ctx, ggml_transpose(ctx, x)), 0);
// The equivalent is to concatenate the columns of conv_states and x,
// then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, // copy last (d_conv - 1) columns back into the state cache
// then element-wise multiply that with the conv1d weigth, struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
ggml_build_forward_expand(graph,
ggml_cpy(ctx, last_conv,
ggml_view_1d(ctx, conv_states_all,
(d_conv - 1)*(d_inner)*(n_seqs),
rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
// 1D convolution
// The equivalent is to make a self-overlapping view of conv_x
// over d_conv columns at each stride in the 3rd dimension,
// then element-wise multiply that with the conv1d weight,
// then sum the elements of each row, // then sum the elements of each row,
// (the last two steps are a dot product over rows (also doable with mul_mat)) // (the last two steps are a dot product over rows (also doable with mul_mat))
// then permute away the ne[0] dimension, // then permute away the ne[0] dimension,
// and then you're left with the resulting x tensor. // and then you're left with the resulting x tensor.
// The new conv_states is the last (d_conv - 1) columns
// of the last 3rd dimensional "layer" of the self-overlapping view.
// For simultaneous sequences, all sequences need to have the same length. // For simultaneous sequences, all sequences need to have the same length.
x = ggml_ssm_conv(ctx, conv, x, model.layers[il].ssm_conv1d);
// ensure conv is updated before copying into the recurrent state cache // For some reason, im2col expects a F16 kernel, but doesn't even read from it.
ggml_build_forward_expand(graph, x); // TODO: make im2col accept F32 kernels to directly pass ssm_conv1d to it.
// => { d_conv * d_inner, n_seq_tokens, n_seqs}
x = ggml_im2col(ctx,
ggml_new_tensor_2d(ctx, GGML_TYPE_F16, d_conv, d_inner),
conv_x, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F32);
ggml_build_forward_expand(graph, x = ggml_reshape_4d(ctx, x, d_conv, 1, d_inner, n_seq_tokens * n_seqs);
ggml_cpy(ctx, conv_states,
ggml_view_1d(ctx, conv_states_all, // => {1, 1, d_inner, n_seq_tokens * n_seqs}
(d_conv - 1)*(d_inner)*(n_rs), x = ggml_mul_mat(ctx, ggml_reshape_3d(ctx, model.layers[il].ssm_conv1d, d_conv, 1, d_inner), x);
rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs);
// Alternatively, this does the same as the above
// x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
// bias // bias
x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
@ -8746,16 +8775,16 @@ static struct ggml_tensor * llm_build_mamba(
// Custom operator to optimize the parallel associative scan // Custom operator to optimize the parallel associative scan
// as described in the Annex D of the Mamba paper. // as described in the Annex D of the Mamba paper.
// => {d_inner, n_seq_tokens, n_seqs} // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
struct ggml_tensor * y = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
// The ssm scan also changes the state, ensure it's done before copying to the recurrent state cache
ggml_build_forward_expand(graph, y);
// store last states // store last states
ggml_build_forward_expand(graph, ggml_build_forward_expand(graph,
ggml_cpy(ctx, ssm_states, ggml_cpy(ctx,
ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
// TODO: skip computing output earlier for unused tokens // TODO: skip computing output earlier for unused tokens