llama : advanced batch splits

This includes equal-sequence-length batch splits which are useful
to simplify recurrent model operators.

* llama : always make recurrent state slots contiguous

* ggml : simplify mamba operators
This commit is contained in:
Francis Couture-Harpin 2024-07-16 20:33:45 -04:00
parent a38b884c6c
commit c51daefc32
3 changed files with 1056 additions and 643 deletions

View File

@ -1760,10 +1760,8 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_ssm_conv( GGML_API struct ggml_tensor * ggml_ssm_conv(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * s, struct ggml_tensor * sx,
struct ggml_tensor * x, struct ggml_tensor * c);
struct ggml_tensor * c,
struct ggml_tensor * sq);
GGML_API struct ggml_tensor * ggml_ssm_scan( GGML_API struct ggml_tensor * ggml_ssm_scan(
struct ggml_context * ctx, struct ggml_context * ctx,
@ -1772,8 +1770,7 @@ extern "C" {
struct ggml_tensor * dt, struct ggml_tensor * dt,
struct ggml_tensor * A, struct ggml_tensor * A,
struct ggml_tensor * B, struct ggml_tensor * B,
struct ggml_tensor * C, struct ggml_tensor * C);
struct ggml_tensor * sq);
// partition into non-overlapping windows with padding if needed // partition into non-overlapping windows with padding if needed
// example: // example:

View File

@ -7082,43 +7082,34 @@ struct ggml_tensor * ggml_flash_attn_back(
struct ggml_tensor * ggml_ssm_conv( struct ggml_tensor * ggml_ssm_conv(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * s, struct ggml_tensor * sx,
struct ggml_tensor * x, struct ggml_tensor * c) {
struct ggml_tensor * c, GGML_ASSERT(ggml_is_3d(sx));
struct ggml_tensor * sq) {
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(ggml_is_matrix(x));
GGML_ASSERT(ggml_is_matrix(c)); GGML_ASSERT(ggml_is_matrix(c));
GGML_ASSERT(ggml_is_matrix(sq));
GGML_ASSERT(sq->type == GGML_TYPE_I32);
const int64_t d_conv = c->ne[0]; const int64_t d_conv = c->ne[0];
const int64_t d_inner = c->ne[1]; const int64_t d_inner = c->ne[1];
const int64_t n_tokens = x->ne[1]; const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence
const int64_t n_kv = s->ne[2]; const int64_t n_s = sx->ne[2];
GGML_ASSERT( s->ne[0] == d_conv - 1); // TODO: maybe support other strides than 1?
GGML_ASSERT( s->ne[1] == d_inner); GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
GGML_ASSERT( x->ne[0] == d_inner); GGML_ASSERT(sx->ne[1] == d_inner);
GGML_ASSERT(sq->ne[0] == n_kv); GGML_ASSERT(n_t >= 0);
GGML_ASSERT(sq->ne[1] == n_tokens);
bool is_node = false; bool is_node = false;
if (s->grad || x->grad || c->grad || sq->grad) { if (sx->grad || c->grad) {
GGML_ASSERT(false); // TODO: implement GGML_ASSERT(false); // TODO: implement
is_node = true; is_node = true;
} }
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
result->op = GGML_OP_SSM_CONV; result->op = GGML_OP_SSM_CONV;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = s; result->src[0] = sx;
result->src[1] = x; result->src[1] = c;
result->src[2] = c;
result->src[3] = sq;
return result; return result;
} }
@ -7132,39 +7123,42 @@ struct ggml_tensor * ggml_ssm_scan(
struct ggml_tensor * dt, struct ggml_tensor * dt,
struct ggml_tensor * A, struct ggml_tensor * A,
struct ggml_tensor * B, struct ggml_tensor * B,
struct ggml_tensor * C, struct ggml_tensor * C) {
struct ggml_tensor * sq) {
GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(s));
GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A)); GGML_ASSERT(ggml_is_contiguous(A));
GGML_ASSERT(sq->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_matrix(A));
GGML_ASSERT(ggml_is_3d(B));
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
GGML_ASSERT(ggml_are_same_shape(x, dt)); GGML_ASSERT(ggml_are_same_shape(x, dt));
GGML_ASSERT(ggml_are_same_shape(B, C));
{ {
const int64_t d_state = s->ne[0]; const int64_t d_state = s->ne[0];
const int64_t d_inner = s->ne[1]; const int64_t d_inner = s->ne[1];
const int64_t n_tokens = x->ne[1]; const int64_t n_seq_tokens = x->ne[1];
const int64_t n_seqs = x->ne[2];
GGML_ASSERT(s->ne[2] == n_seqs);
GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(x->ne[0] == d_inner);
GGML_ASSERT(A->ne[0] == d_state); GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == d_inner); GGML_ASSERT(A->ne[1] == d_inner);
GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[0] == d_state);
GGML_ASSERT(B->ne[1] == n_tokens); GGML_ASSERT(B->ne[1] == n_seq_tokens);
GGML_ASSERT(C->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seqs);
GGML_ASSERT(C->ne[1] == n_tokens);
} }
bool is_node = false; bool is_node = false;
if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
GGML_ASSERT(false); // TODO: implement GGML_ASSERT(false); // TODO: implement
is_node = true; is_node = true;
} }
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv} // concatenated y + ssm_states
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
result->op = GGML_OP_SSM_SCAN; result->op = GGML_OP_SSM_SCAN;
@ -7175,7 +7169,6 @@ struct ggml_tensor * ggml_ssm_scan(
result->src[3] = A; result->src[3] = A;
result->src[4] = B; result->src[4] = B;
result->src[5] = C; result->src[5] = C;
result->src[6] = sq;
return result; return result;
} }
@ -10839,11 +10832,6 @@ static void ggml_compute_forward_concat_f32(
GGML_TENSOR_BINARY_OP_LOCALS GGML_TENSOR_BINARY_OP_LOCALS
// TODO: support for transposed / permuted tensors
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
const int32_t dim = ggml_get_op_params_i32(dst, 0); const int32_t dim = ggml_get_op_params_i32(dst, 0);
GGML_ASSERT(dim >= 0 && dim < 4); GGML_ASSERT(dim >= 0 && dim < 4);
@ -15546,27 +15534,22 @@ static void ggml_compute_forward_flash_attn_back(
static void ggml_compute_forward_ssm_conv_f32( static void ggml_compute_forward_ssm_conv_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_state const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
const struct ggml_tensor * src3 = dst->src[3]; // state_seq
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
const int nc = src2->ne[0]; // d_conv const int nc = src1->ne[0]; // d_conv
const int nr = src0->ne[1]; // d_inner const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const int n_t = src1->ne[1]; // n_tokens const int nr = src0->ne[1]; // d_inner
const int n_kv = src0->ne[2]; // max number of sequences in the batch const int n_t = dst->ne[1]; // tokens per sequence
const int n_s = dst->ne[2]; // number of sequences in the batch
GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); GGML_ASSERT( dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// for use with the destination state offset between sequences
GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
// rows per thread // rows per thread
const int dr = (nr + nth - 1)/nth; const int dr = (nr + nth - 1)/nth;
@ -15576,76 +15559,29 @@ static void ggml_compute_forward_ssm_conv_f32(
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0; const int ir = ir1 - ir0;
if (n_kv > 1) { for (int i3 = 0; i3 < n_s; ++i3) {
// multiple sequences means it's hard to know when it's the first time a state is read, for (int i2 = 0; i2 < n_t; ++i2) {
// so copy them all over to the destination, just to be sure. // {d_conv - 1 + n_t, d_inner, n_seqs}
for (int i3 = 0; i3 < n_kv; ++i3) { // sliding window
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
// can't use memcpy because of d_conv vs d_conv - 1 float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
// TODO: transpose the output for smaller strides for big batches?
// d_inner
for (int i1 = 0; i1 < ir; ++i1) { for (int i1 = 0; i1 < ir; ++i1) {
for (int i0 = 0; i0 < nc - 1; ++i0) { // rowwise dot product
// copy s0 to last (d_conv - 1) columns of s // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; float sumf = 0.0f;
// d_conv
for (int i0 = 0; i0 < nc; ++i0) {
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
} }
x[i1] = sumf;
} }
} }
} }
for (int i2 = 0; i2 < n_t; ++i2) {
int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
float * s0; // {d_conv - 1, d_inner, n_kv}
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
int ne0s0;
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
// avoid needing to copy the state for the first token
if (i2 == 0) {
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
ne0s0 = src0->ne[0];
} else {
// the source is the last (d_conv - 1) columns of the destination
s0 = s + 1;
ne0s0 = nc;
}
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// shift state left
for (int i0 = 0; i0 < nc - 1; ++i0) {
s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
}
// insert x on the last column
s[(nc - 1) + i1*nc] = x0[i1];
}
// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
}
}
// it seems a little faster when this is separate from the state shift
for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product
float sumf = 0.0f;
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
sumf += s[i] * c[i];
}
x[i1] = sumf;
}
}
} }
static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_conv(
@ -15674,15 +15610,14 @@ static void ggml_compute_forward_ssm_scan_f32(
const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C const struct ggml_tensor * src5 = dst->src[5]; // C
const struct ggml_tensor * src6 = dst->src[6]; // sq
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
const int64_t nc = src0->ne[0]; // d_state const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens in the batch const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch const int64_t n_s = src0->ne[2]; // number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
@ -15691,12 +15626,12 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C, and when copying the states // required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states // required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[2]) // required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float)); GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
// rows per thread // rows per thread
const int dr = (nr + nth - 1)/nth; const int dr = (nr + nth - 1)/nth;
@ -15706,64 +15641,36 @@ static void ggml_compute_forward_ssm_scan_f32(
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0; const int ir = ir1 - ir0;
if (n_kv > 1) { for (int i3 = 0; i3 < n_s; ++i3) {
// it's hard to know if the source states have already been copied for (int i2 = 0; i2 < n_t; ++i2) {
// when there are multiple, so copy them already. const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
for (int i3 = 0; i3 < n_kv; ++i3) { const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
memcpy(s, s0, nc*ir*sizeof(float)); const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
} float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
for (int i2 = 0; i2 < n_t; ++i2) { // use the output as the source for the next token-wise iterations
int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} if (i2 > 0) { s0 = s; }
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
float * s0;
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); // d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// avoid needing to copy the state for the first token // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
if (i2 == 0) { float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} float x_dt = x[i1] * dt_soft_plus;
} else { float sumf = 0.0f;
// otherwise the source is the same as the destination // d_state
s0 = s; for (int i0 = 0; i0 < nc; ++i0) {
} int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
// d_inner float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
for (int i1 = 0; i1 < ir; ++i1) { // y = rowwise_dotprod(state, C)
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 sumf += state * C[i0];
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; s[i] = state;
float x_dt = x[i1] * dt_soft_plus; }
float sumf = 0.0f; y[i1] = sumf;
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[i0];
s[i] = state;
}
y[i1] = sumf;
}
// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
} }
} }
} }

File diff suppressed because it is too large Load Diff