mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 14:28:58 +01:00
llama : use equal-sequence-length sub-batches for recurrent models
* ggml : simplify SSM-related operators * llama : make recurrent state slot allocation contiguous * llama : adapt internal uses of batches to llama_ubatch
This commit is contained in:
parent
4e4c41e553
commit
3587a94987
166
ggml.c
166
ggml.c
@ -7103,40 +7103,35 @@ struct ggml_tensor * ggml_ssm_conv(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * s,
|
||||
struct ggml_tensor * x,
|
||||
struct ggml_tensor * c,
|
||||
struct ggml_tensor * sq) {
|
||||
struct ggml_tensor * c) {
|
||||
GGML_ASSERT(ggml_is_3d(s));
|
||||
GGML_ASSERT(ggml_is_matrix(x));
|
||||
GGML_ASSERT(ggml_is_3d(x));
|
||||
GGML_ASSERT(ggml_is_matrix(c));
|
||||
GGML_ASSERT(ggml_is_vector(sq));
|
||||
GGML_ASSERT(sq->type == GGML_TYPE_I32);
|
||||
|
||||
const int64_t d_conv = c->ne[0];
|
||||
const int64_t d_inner = c->ne[1];
|
||||
const int64_t n_tokens = x->ne[1];
|
||||
const int64_t n_rs = s->ne[2];
|
||||
const int64_t n_t = x->ne[1]; // tokens per sequence
|
||||
const int64_t n_s = s->ne[2];
|
||||
|
||||
GGML_ASSERT(s->ne[0] == d_conv - 1);
|
||||
GGML_ASSERT(s->ne[1] == d_inner);
|
||||
GGML_ASSERT(x->ne[0] == d_inner);
|
||||
GGML_ASSERT(sq->ne[0] == n_tokens);
|
||||
GGML_ASSERT(x->ne[2] == n_s);
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (s->grad || x->grad || c->grad || sq->grad) {
|
||||
if (s->grad || x->grad || c->grad) {
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs}
|
||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs));
|
||||
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
|
||||
|
||||
result->op = GGML_OP_SSM_CONV;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = s;
|
||||
result->src[1] = x;
|
||||
result->src[2] = c;
|
||||
result->src[3] = sq;
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -7150,40 +7145,43 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||
struct ggml_tensor * dt,
|
||||
struct ggml_tensor * A,
|
||||
struct ggml_tensor * B,
|
||||
struct ggml_tensor * C,
|
||||
struct ggml_tensor * sq) {
|
||||
struct ggml_tensor * C) {
|
||||
GGML_ASSERT(ggml_is_contiguous(s));
|
||||
GGML_ASSERT(ggml_is_contiguous(x));
|
||||
GGML_ASSERT(ggml_is_contiguous(dt));
|
||||
GGML_ASSERT(ggml_is_contiguous(A));
|
||||
GGML_ASSERT(sq->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(ggml_is_matrix(A));
|
||||
GGML_ASSERT(ggml_is_3d(B));
|
||||
GGML_ASSERT(ggml_is_3d(s));
|
||||
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
||||
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
|
||||
GGML_ASSERT(ggml_are_same_shape(x, dt));
|
||||
GGML_ASSERT(ggml_are_same_shape(B, C));
|
||||
|
||||
{
|
||||
const int64_t d_state = s->ne[0];
|
||||
const int64_t d_inner = s->ne[1];
|
||||
const int64_t n_tokens = x->ne[1];
|
||||
const int64_t n_seq_tokens = x->ne[1];
|
||||
const int64_t n_seqs = x->ne[2];
|
||||
|
||||
GGML_ASSERT(s->ne[2] == n_seqs);
|
||||
GGML_ASSERT(x->ne[0] == d_inner);
|
||||
GGML_ASSERT(A->ne[0] == d_state);
|
||||
GGML_ASSERT(A->ne[1] == d_inner);
|
||||
GGML_ASSERT(B->ne[0] == d_state);
|
||||
GGML_ASSERT(B->ne[1] == n_tokens);
|
||||
GGML_ASSERT(C->ne[0] == d_state);
|
||||
GGML_ASSERT(C->ne[1] == n_tokens);
|
||||
GGML_ASSERT(B->ne[1] == n_seq_tokens);
|
||||
GGML_ASSERT(B->ne[2] == n_seqs);
|
||||
}
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
|
||||
if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs}
|
||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
|
||||
// y
|
||||
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]);
|
||||
|
||||
result->op = GGML_OP_SSM_SCAN;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
@ -7193,7 +7191,6 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||
result->src[3] = A;
|
||||
result->src[4] = B;
|
||||
result->src[5] = C;
|
||||
result->src[6] = sq;
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -16249,24 +16246,20 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // conv_state
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // x
|
||||
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
|
||||
const struct ggml_tensor * src3 = dst->src[3]; // state_seq
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nc = src2->ne[0]; // d_conv
|
||||
const int nr = src0->ne[1]; // d_inner
|
||||
const int n_t = src1->ne[1]; // n_tokens
|
||||
const int n_rs = src0->ne[2]; // max number of sequences in the batch
|
||||
const int n_t = src1->ne[1]; // tokens per sequence
|
||||
const int n_s = src0->ne[2]; // number of sequences in the batch
|
||||
|
||||
GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst));
|
||||
GGML_ASSERT(ggml_are_same_shape(src1, dst));
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
|
||||
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||
// for use with the destination state offset between sequences
|
||||
GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
@ -16276,51 +16269,31 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
const int ir = ir1 - ir0;
|
||||
|
||||
const int32_t * sq = src3->data; // {n_tokens}
|
||||
// TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
|
||||
// This would avoid having to copy into an intermediate buffer, but the state would be bigger.
|
||||
float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
if (n_rs > 1) {
|
||||
// multiple sequences means it's hard to know when it's the first time a state is read,
|
||||
// so copy them all over to the destination, just to be sure.
|
||||
for (int i3 = 0; i3 < n_rs; ++i3) {
|
||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
|
||||
// can't use memcpy because of d_conv vs d_conv - 1
|
||||
for (int i3 = 0; i3 < n_s; ++i3) {
|
||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
|
||||
|
||||
// copy the state into working memory
|
||||
// can't use memcpy because (d_conv) != (d_conv - 1)
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
||||
// copy s0 to last (d_conv - 1) columns of s
|
||||
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||
int32_t sq_i = sq[i2];
|
||||
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs}
|
||||
float * s0; // {d_conv - 1, d_inner, n_rs}
|
||||
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||
float * x = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s}
|
||||
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
||||
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
|
||||
int ne0s0;
|
||||
|
||||
GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
|
||||
|
||||
// avoid needing to copy the state for the first token
|
||||
if (i2 == 0) {
|
||||
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs}
|
||||
ne0s0 = src0->ne[0];
|
||||
} else {
|
||||
// the source is the last (d_conv - 1) columns of the destination
|
||||
s0 = s + 1;
|
||||
ne0s0 = nc;
|
||||
}
|
||||
// shift state left
|
||||
memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
|
||||
|
||||
// d_inner
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
// shift state left
|
||||
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
||||
s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
|
||||
}
|
||||
// insert x on the last column
|
||||
s[(nc - 1) + i1*nc] = x0[i1];
|
||||
}
|
||||
@ -16328,6 +16301,7 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||
// it seems a little faster when this is separate from the state shift
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
// rowwise dot product
|
||||
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
|
||||
float sumf = 0.0f;
|
||||
for (int i0 = 0; i0 < nc; ++i0) {
|
||||
int i = i0 + i1*nc;
|
||||
@ -16336,6 +16310,14 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||
x[i1] = sumf;
|
||||
}
|
||||
}
|
||||
|
||||
// copy the state out of it
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
||||
s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_ssm_conv(
|
||||
@ -16368,30 +16350,24 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
const struct ggml_tensor * src3 = dst->src[3]; // A
|
||||
const struct ggml_tensor * src4 = dst->src[4]; // B
|
||||
const struct ggml_tensor * src5 = dst->src[5]; // C
|
||||
const struct ggml_tensor * src6 = dst->src[6]; // sq
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t nc = src0->ne[0]; // d_state
|
||||
const int64_t nr = src0->ne[1]; // d_inner
|
||||
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
|
||||
const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch
|
||||
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
|
||||
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
|
||||
|
||||
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
||||
GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst));
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
||||
// required for the dot product between s and C, and when copying the states
|
||||
// required for the dot product between s and C
|
||||
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||
// required for per-sequence offsets for states
|
||||
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
||||
// required to get correct offset for state destination (i.e. src1->nb[2])
|
||||
GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
@ -16401,38 +16377,15 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
const int ir = ir1 - ir0;
|
||||
|
||||
const int32_t * sq = src6->data; // {n_tokens}
|
||||
|
||||
if (n_rs > 1) {
|
||||
// it's hard to know if the source states have already been copied
|
||||
// when there are multiple, so copy them already.
|
||||
for (int i3 = 0; i3 < n_rs; ++i3) {
|
||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
|
||||
memcpy(s, s0, nc*ir*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
for (int i3 = 0; i3 < n_s; ++i3) {
|
||||
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||
int32_t sq_i = sq[i2];
|
||||
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs}
|
||||
float * s0;
|
||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
|
||||
float * y = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s}
|
||||
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
|
||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
|
||||
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
||||
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
|
||||
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
|
||||
|
||||
GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
|
||||
|
||||
// avoid needing to copy the state for the first token
|
||||
if (i2 == 0) {
|
||||
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs}
|
||||
} else {
|
||||
// otherwise the source is the same as the destination
|
||||
s0 = s;
|
||||
}
|
||||
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
||||
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
||||
|
||||
// d_inner
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
@ -16444,7 +16397,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
for (int i0 = 0; i0 < nc; ++i0) {
|
||||
int i = i0 + i1*nc;
|
||||
// state = prev_state * dA + dB * x
|
||||
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||
float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||
// y = rowwise_dotprod(state, C)
|
||||
sumf += state * C[i0];
|
||||
s[i] = state;
|
||||
@ -16453,6 +16406,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_ssm_scan(
|
||||
const struct ggml_compute_params * params,
|
||||
@ -19614,7 +19568,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
{
|
||||
const int64_t d_conv = node->src[2]->ne[0];
|
||||
const int64_t d_inner = node->src[0]->ne[1];
|
||||
|
||||
cur += sizeof(float)*d_conv*(d_inner + n_tasks - 1);
|
||||
} break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
{
|
||||
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
|
||||
|
6
ggml.h
6
ggml.h
@ -1793,8 +1793,7 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * s,
|
||||
struct ggml_tensor * x,
|
||||
struct ggml_tensor * c,
|
||||
struct ggml_tensor * sq);
|
||||
struct ggml_tensor * c);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_ssm_scan(
|
||||
struct ggml_context * ctx,
|
||||
@ -1803,8 +1802,7 @@ extern "C" {
|
||||
struct ggml_tensor * dt,
|
||||
struct ggml_tensor * A,
|
||||
struct ggml_tensor * B,
|
||||
struct ggml_tensor * C,
|
||||
struct ggml_tensor * sq);
|
||||
struct ggml_tensor * C);
|
||||
|
||||
// partition into non-overlapping windows with padding if needed
|
||||
// example:
|
||||
|
Loading…
Reference in New Issue
Block a user