mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-19 08:20:10 +01:00
llama : use im2col and mul_mat to perform convolution for Mamba
This removes the need for ggml_ssm_conv!!! But performance seems slighly worse on my system, especially for prompt processing. Maybe ggml_mul_mat isn't optimized for small row sizes? More performance testing is necessary until GGML_OP_SSM_CONV is removed. * ggml : make ggml_ssm_scan not modify its source tensors * llama : fix shared recurrent tail cell count for small ubatch sizes Otherwise it was impossible to run the 'parallel' example with '-ub 1' with a Mamba or Jamba model.
This commit is contained in:
parent
eb589d5e36
commit
8fb57ac0fb
121
ggml.c
121
ggml.c
@ -7124,26 +7124,24 @@ struct ggml_tensor * ggml_flash_attn_back(
|
|||||||
|
|
||||||
struct ggml_tensor * ggml_ssm_conv(
|
struct ggml_tensor * ggml_ssm_conv(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * s,
|
struct ggml_tensor * sx,
|
||||||
struct ggml_tensor * x,
|
|
||||||
struct ggml_tensor * c) {
|
struct ggml_tensor * c) {
|
||||||
GGML_ASSERT(ggml_is_3d(s));
|
GGML_ASSERT(ggml_is_3d(sx));
|
||||||
GGML_ASSERT(ggml_is_3d(x));
|
|
||||||
GGML_ASSERT(ggml_is_matrix(c));
|
GGML_ASSERT(ggml_is_matrix(c));
|
||||||
|
|
||||||
const int64_t d_conv = c->ne[0];
|
const int64_t d_conv = c->ne[0];
|
||||||
const int64_t d_inner = c->ne[1];
|
const int64_t d_inner = c->ne[1];
|
||||||
const int64_t n_t = x->ne[1]; // tokens per sequence
|
const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence
|
||||||
const int64_t n_s = s->ne[2];
|
const int64_t n_s = sx->ne[2];
|
||||||
|
|
||||||
GGML_ASSERT(s->ne[0] == d_conv - 1);
|
// TODO: maybe support other strides than 1?
|
||||||
GGML_ASSERT(s->ne[1] == d_inner);
|
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
|
||||||
GGML_ASSERT(x->ne[0] == d_inner);
|
GGML_ASSERT(sx->ne[1] == d_inner);
|
||||||
GGML_ASSERT(x->ne[2] == n_s);
|
GGML_ASSERT(n_t >= 0);
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
if (s->grad || x->grad || c->grad) {
|
if (sx->grad || c->grad) {
|
||||||
GGML_ASSERT(false); // TODO: implement
|
GGML_ASSERT(false); // TODO: implement
|
||||||
is_node = true;
|
is_node = true;
|
||||||
}
|
}
|
||||||
@ -7152,9 +7150,8 @@ struct ggml_tensor * ggml_ssm_conv(
|
|||||||
|
|
||||||
result->op = GGML_OP_SSM_CONV;
|
result->op = GGML_OP_SSM_CONV;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
result->src[0] = s;
|
result->src[0] = sx;
|
||||||
result->src[1] = x;
|
result->src[1] = c;
|
||||||
result->src[2] = c;
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -7203,8 +7200,8 @@ struct ggml_tensor * ggml_ssm_scan(
|
|||||||
is_node = true;
|
is_node = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// y
|
// concatenated y + ssm_states
|
||||||
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]);
|
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
|
||||||
|
|
||||||
result->op = GGML_OP_SSM_SCAN;
|
result->op = GGML_OP_SSM_SCAN;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
@ -16252,22 +16249,21 @@ static void ggml_compute_forward_ssm_conv_f32(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const struct ggml_tensor * src0 = dst->src[0]; // conv_state
|
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
|
||||||
const struct ggml_tensor * src1 = dst->src[1]; // x
|
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
|
||||||
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
|
|
||||||
|
|
||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
const int nc = src2->ne[0]; // d_conv
|
const int nc = src1->ne[0]; // d_conv
|
||||||
|
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
|
||||||
const int nr = src0->ne[1]; // d_inner
|
const int nr = src0->ne[1]; // d_inner
|
||||||
const int n_t = src1->ne[1]; // tokens per sequence
|
const int n_t = dst->ne[1]; // tokens per sequence
|
||||||
const int n_s = src0->ne[2]; // number of sequences in the batch
|
const int n_s = dst->ne[2]; // number of sequences in the batch
|
||||||
|
|
||||||
GGML_ASSERT(ggml_are_same_shape(src1, dst));
|
GGML_ASSERT( dst->ne[0] == nr);
|
||||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
|
||||||
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||||
|
|
||||||
// rows per thread
|
// rows per thread
|
||||||
@ -16278,54 +16274,28 @@ static void ggml_compute_forward_ssm_conv_f32(
|
|||||||
const int ir1 = MIN(ir0 + dr, nr);
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
const int ir = ir1 - ir0;
|
const int ir = ir1 - ir0;
|
||||||
|
|
||||||
// TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
|
|
||||||
// This would avoid having to copy into an intermediate buffer, but the state would be bigger.
|
|
||||||
float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
|
|
||||||
|
|
||||||
for (int i3 = 0; i3 < n_s; ++i3) {
|
for (int i3 = 0; i3 < n_s; ++i3) {
|
||||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
|
|
||||||
|
|
||||||
// copy the state into working memory
|
|
||||||
// can't use memcpy because (d_conv) != (d_conv - 1)
|
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
|
||||||
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
|
||||||
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i2 = 0; i2 < n_t; ++i2) {
|
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||||
float * x = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s}
|
// {d_conv - 1 + n_t, d_inner, n_seqs}
|
||||||
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
// sliding window
|
||||||
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
|
const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
|
||||||
|
const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
|
||||||
// shift state left
|
float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
|
||||||
memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
|
|
||||||
|
|
||||||
|
// TODO: transpose the output for smaller strides for big batches?
|
||||||
// d_inner
|
// d_inner
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
|
||||||
// insert x on the last column
|
|
||||||
s[(nc - 1) + i1*nc] = x0[i1];
|
|
||||||
}
|
|
||||||
|
|
||||||
// it seems a little faster when this is separate from the state shift
|
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
for (int i1 = 0; i1 < ir; ++i1) {
|
||||||
// rowwise dot product
|
// rowwise dot product
|
||||||
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
|
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
|
||||||
float sumf = 0.0f;
|
float sumf = 0.0f;
|
||||||
|
|
||||||
|
// d_conv
|
||||||
for (int i0 = 0; i0 < nc; ++i0) {
|
for (int i0 = 0; i0 < nc; ++i0) {
|
||||||
int i = i0 + i1*nc;
|
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
|
||||||
sumf += s[i] * c[i];
|
|
||||||
}
|
}
|
||||||
x[i1] = sumf;
|
x[i1] = sumf;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy the state out of it
|
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
|
||||||
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
|
||||||
s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -16368,7 +16338,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||||||
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
|
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
|
||||||
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
|
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
|
||||||
|
|
||||||
GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst));
|
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
||||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
||||||
@ -16377,6 +16347,10 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||||||
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
||||||
// required for the dot product between s and C
|
// required for the dot product between s and C
|
||||||
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||||
|
// required for per-sequence offsets for states
|
||||||
|
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
||||||
|
// required to get correct offset for state destination (i.e. src1->nb[3])
|
||||||
|
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
|
||||||
|
|
||||||
// rows per thread
|
// rows per thread
|
||||||
const int dr = (nr + nth - 1)/nth;
|
const int dr = (nr + nth - 1)/nth;
|
||||||
@ -16388,13 +16362,17 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||||||
|
|
||||||
for (int i3 = 0; i3 < n_s; ++i3) {
|
for (int i3 = 0; i3 < n_s; ++i3) {
|
||||||
for (int i2 = 0; i2 < n_t; ++i2) {
|
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||||
float * y = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s}
|
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
|
||||||
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
|
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
||||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
|
||||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
|
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
||||||
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
||||||
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
||||||
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
||||||
|
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
|
||||||
|
|
||||||
|
// use the output as the source for the next token-wise iterations
|
||||||
|
if (i2 > 0) { s0 = s; }
|
||||||
|
|
||||||
// d_inner
|
// d_inner
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
for (int i1 = 0; i1 < ir; ++i1) {
|
||||||
@ -16406,7 +16384,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||||||
for (int i0 = 0; i0 < nc; ++i0) {
|
for (int i0 = 0; i0 < nc; ++i0) {
|
||||||
int i = i0 + i1*nc;
|
int i = i0 + i1*nc;
|
||||||
// state = prev_state * dA + dB * x
|
// state = prev_state * dA + dB * x
|
||||||
float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||||
// y = rowwise_dotprod(state, C)
|
// y = rowwise_dotprod(state, C)
|
||||||
sumf += state * C[i0];
|
sumf += state * C[i0];
|
||||||
s[i] = state;
|
s[i] = state;
|
||||||
@ -19577,13 +19555,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
|||||||
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SSM_CONV:
|
|
||||||
{
|
|
||||||
const int64_t d_conv = node->src[2]->ne[0];
|
|
||||||
const int64_t d_inner = node->src[0]->ne[1];
|
|
||||||
|
|
||||||
cur += sizeof(float)*d_conv*(d_inner + n_tasks - 1);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
{
|
{
|
||||||
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
|
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
|
||||||
|
3
ggml.h
3
ggml.h
@ -1803,8 +1803,7 @@ extern "C" {
|
|||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_ssm_conv(
|
GGML_API struct ggml_tensor * ggml_ssm_conv(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * s,
|
struct ggml_tensor * sx,
|
||||||
struct ggml_tensor * x,
|
|
||||||
struct ggml_tensor * c);
|
struct ggml_tensor * c);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_ssm_scan(
|
GGML_API struct ggml_tensor * ggml_ssm_scan(
|
||||||
|
79
llama.cpp
79
llama.cpp
@ -2827,11 +2827,13 @@ struct llama_rs_cache {
|
|||||||
n_shared_tail_cells += 1;
|
n_shared_tail_cells += 1;
|
||||||
n_seqs -= 1;
|
n_seqs -= 1;
|
||||||
}
|
}
|
||||||
} else if (rs_cell.is_empty()) {
|
} else {
|
||||||
|
if (rs_cell.is_empty()) {
|
||||||
// from shared to unique
|
// from shared to unique
|
||||||
n_seqs += 1;
|
n_seqs += 1;
|
||||||
if (prev_cell.tail_rc == 1) {
|
}
|
||||||
// it was the last tail of the previous cell
|
if (prev_cell.tail_rc == 1 && rs_cell.seq_nodes.size() == rs_cell.tail_rc) {
|
||||||
|
// from last shared to fully tail
|
||||||
n_shared_tail_cells -= 1;
|
n_shared_tail_cells -= 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -8683,6 +8685,18 @@ static struct ggml_tensor * llm_build_mamba(
|
|||||||
conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs);
|
conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs);
|
||||||
ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs);
|
ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs);
|
||||||
|
|
||||||
|
// copy states which won't be changed further (between n_seqs and n_rs)
|
||||||
|
ggml_build_forward_expand(graph,
|
||||||
|
ggml_cpy(ctx,
|
||||||
|
ggml_view_1d(ctx, conv_states, (d_conv - 1)*d_inner*(n_rs - n_seqs), n_seqs*(conv_states->nb[2])),
|
||||||
|
ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*(d_conv - 1)*d_inner*ggml_element_size(conv_states_all))));
|
||||||
|
|
||||||
|
ggml_build_forward_expand(graph,
|
||||||
|
ggml_cpy(ctx,
|
||||||
|
ggml_view_1d(ctx, ssm_states, d_state*d_inner*(n_rs - n_seqs), n_seqs*(ssm_states->nb[2])),
|
||||||
|
ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*d_state*d_inner*ggml_element_size(ssm_states_all))));
|
||||||
|
|
||||||
|
// the part of the states that will be used and modified
|
||||||
struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0);
|
struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0);
|
||||||
struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0);
|
struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0);
|
||||||
|
|
||||||
@ -8698,28 +8712,43 @@ static struct ggml_tensor * llm_build_mamba(
|
|||||||
|
|
||||||
// conv
|
// conv
|
||||||
{
|
{
|
||||||
// Custom operator, which is needed because self-overlapping views aren't yet well supported by ggml.
|
// => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
|
||||||
// And also because this uses much less memory for large batches (4 times less when d_conv is 4).
|
struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_cont(ctx, ggml_transpose(ctx, x)), 0);
|
||||||
// The equivalent is to concatenate the columns of conv_states and x,
|
|
||||||
// then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
|
// copy last (d_conv - 1) columns back into the state cache
|
||||||
// then element-wise multiply that with the conv1d weigth,
|
struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
|
||||||
|
|
||||||
|
ggml_build_forward_expand(graph,
|
||||||
|
ggml_cpy(ctx, last_conv,
|
||||||
|
ggml_view_1d(ctx, conv_states_all,
|
||||||
|
(d_conv - 1)*(d_inner)*(n_seqs),
|
||||||
|
rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
|
||||||
|
|
||||||
|
// 1D convolution
|
||||||
|
// The equivalent is to make a self-overlapping view of conv_x
|
||||||
|
// over d_conv columns at each stride in the 3rd dimension,
|
||||||
|
// then element-wise multiply that with the conv1d weight,
|
||||||
// then sum the elements of each row,
|
// then sum the elements of each row,
|
||||||
// (the last two steps are a dot product over rows (also doable with mul_mat))
|
// (the last two steps are a dot product over rows (also doable with mul_mat))
|
||||||
// then permute away the ne[0] dimension,
|
// then permute away the ne[0] dimension,
|
||||||
// and then you're left with the resulting x tensor.
|
// and then you're left with the resulting x tensor.
|
||||||
// The new conv_states is the last (d_conv - 1) columns
|
|
||||||
// of the last 3rd dimensional "layer" of the self-overlapping view.
|
|
||||||
// For simultaneous sequences, all sequences need to have the same length.
|
// For simultaneous sequences, all sequences need to have the same length.
|
||||||
x = ggml_ssm_conv(ctx, conv, x, model.layers[il].ssm_conv1d);
|
|
||||||
|
|
||||||
// ensure conv is updated before copying into the recurrent state cache
|
// For some reason, im2col expects a F16 kernel, but doesn't even read from it.
|
||||||
ggml_build_forward_expand(graph, x);
|
// TODO: make im2col accept F32 kernels to directly pass ssm_conv1d to it.
|
||||||
|
// => { d_conv * d_inner, n_seq_tokens, n_seqs}
|
||||||
|
x = ggml_im2col(ctx,
|
||||||
|
ggml_new_tensor_2d(ctx, GGML_TYPE_F16, d_conv, d_inner),
|
||||||
|
conv_x, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F32);
|
||||||
|
|
||||||
ggml_build_forward_expand(graph,
|
x = ggml_reshape_4d(ctx, x, d_conv, 1, d_inner, n_seq_tokens * n_seqs);
|
||||||
ggml_cpy(ctx, conv_states,
|
|
||||||
ggml_view_1d(ctx, conv_states_all,
|
// => {1, 1, d_inner, n_seq_tokens * n_seqs}
|
||||||
(d_conv - 1)*(d_inner)*(n_rs),
|
x = ggml_mul_mat(ctx, ggml_reshape_3d(ctx, model.layers[il].ssm_conv1d, d_conv, 1, d_inner), x);
|
||||||
rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
|
x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs);
|
||||||
|
|
||||||
|
// Alternatively, this does the same as the above
|
||||||
|
// x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
|
x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
|
||||||
@ -8746,16 +8775,16 @@ static struct ggml_tensor * llm_build_mamba(
|
|||||||
|
|
||||||
// Custom operator to optimize the parallel associative scan
|
// Custom operator to optimize the parallel associative scan
|
||||||
// as described in the Annex D of the Mamba paper.
|
// as described in the Annex D of the Mamba paper.
|
||||||
// => {d_inner, n_seq_tokens, n_seqs}
|
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
|
||||||
struct ggml_tensor * y = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
|
struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
|
||||||
|
|
||||||
// The ssm scan also changes the state, ensure it's done before copying to the recurrent state cache
|
|
||||||
ggml_build_forward_expand(graph, y);
|
|
||||||
|
|
||||||
// store last states
|
// store last states
|
||||||
ggml_build_forward_expand(graph,
|
ggml_build_forward_expand(graph,
|
||||||
ggml_cpy(ctx, ssm_states,
|
ggml_cpy(ctx,
|
||||||
ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
|
ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
|
||||||
|
ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
|
||||||
|
|
||||||
|
struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
|
||||||
|
|
||||||
// TODO: skip computing output earlier for unused tokens
|
// TODO: skip computing output earlier for unused tokens
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user