mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
parent
e84b71c2c6
commit
d48c88cbd5
@ -643,7 +643,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||||||
struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd_head, n_head_kv, n_batch);
|
struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd_head, n_head_kv, n_batch);
|
||||||
struct ggml_tensor * t16;
|
struct ggml_tensor * t16;
|
||||||
if (enable_flash_attn) {
|
if (enable_flash_attn) {
|
||||||
t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
|
GGML_ASSERT(false && "TODO: ggml_flash_attn_ext() not yet supported");
|
||||||
|
//t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
|
||||||
} else {
|
} else {
|
||||||
struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
|
struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
|
||||||
struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
|
struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
|
||||||
|
@ -341,7 +341,8 @@ static struct ggml_tensor * llama_build_train_graphs(
|
|||||||
struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
|
struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
|
||||||
struct ggml_tensor * t16;
|
struct ggml_tensor * t16;
|
||||||
if (enable_flash_attn) {
|
if (enable_flash_attn) {
|
||||||
t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
|
GGML_ASSERT(false && "TODO: ggml_flash_attn_ext() not yet supported");
|
||||||
|
//t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
|
||||||
} else {
|
} else {
|
||||||
struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
|
struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
|
||||||
struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
|
struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
|
||||||
|
676
ggml.c
676
ggml.c
@ -2670,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
"ARGSORT",
|
"ARGSORT",
|
||||||
"LEAKY_RELU",
|
"LEAKY_RELU",
|
||||||
|
|
||||||
"FLASH_ATTN",
|
|
||||||
"FLASH_ATTN_EXT",
|
"FLASH_ATTN_EXT",
|
||||||
"FLASH_FF",
|
|
||||||
"FLASH_ATTN_BACK",
|
"FLASH_ATTN_BACK",
|
||||||
"SSM_CONV",
|
"SSM_CONV",
|
||||||
"SSM_SCAN",
|
"SSM_SCAN",
|
||||||
@ -2698,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
"CROSS_ENTROPY_LOSS_BACK",
|
"CROSS_ENTROPY_LOSS_BACK",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
|
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
@ -2760,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
"argsort(x)",
|
"argsort(x)",
|
||||||
"leaky_relu(x)",
|
"leaky_relu(x)",
|
||||||
|
|
||||||
"flash_attn(x)",
|
|
||||||
"flash_attn_ext(x)",
|
"flash_attn_ext(x)",
|
||||||
"flash_ff(x)",
|
|
||||||
"flash_attn_back(x)",
|
"flash_attn_back(x)",
|
||||||
"ssm_conv(x)",
|
"ssm_conv(x)",
|
||||||
"ssm_scan(x)",
|
"ssm_scan(x)",
|
||||||
@ -2788,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
"cross_entropy_loss_back(x,y)",
|
"cross_entropy_loss_back(x,y)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
|
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
@ -6948,38 +6944,6 @@ struct ggml_tensor * ggml_top_k(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_flash_attn
|
|
||||||
|
|
||||||
struct ggml_tensor * ggml_flash_attn(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * q,
|
|
||||||
struct ggml_tensor * k,
|
|
||||||
struct ggml_tensor * v,
|
|
||||||
bool masked) {
|
|
||||||
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
|
||||||
// TODO: check if vT can be multiplied by (k*qT)
|
|
||||||
|
|
||||||
bool is_node = false;
|
|
||||||
|
|
||||||
if (q->grad || k->grad || v->grad) {
|
|
||||||
is_node = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
//struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
|
|
||||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
|
|
||||||
|
|
||||||
int32_t t = masked ? 1 : 0;
|
|
||||||
ggml_set_op_params(result, &t, sizeof(t));
|
|
||||||
|
|
||||||
result->op = GGML_OP_FLASH_ATTN;
|
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
|
||||||
result->src[0] = q;
|
|
||||||
result->src[1] = k;
|
|
||||||
result->src[2] = v;
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ggml_flash_attn_ext
|
// ggml_flash_attn_ext
|
||||||
|
|
||||||
struct ggml_tensor * ggml_flash_attn_ext(
|
struct ggml_tensor * ggml_flash_attn_ext(
|
||||||
@ -7039,38 +7003,6 @@ void ggml_flash_attn_ext_set_prec(
|
|||||||
ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
|
ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_flash_ff
|
|
||||||
|
|
||||||
struct ggml_tensor * ggml_flash_ff(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
struct ggml_tensor * b0,
|
|
||||||
struct ggml_tensor * b1,
|
|
||||||
struct ggml_tensor * c0,
|
|
||||||
struct ggml_tensor * c1) {
|
|
||||||
GGML_ASSERT(ggml_can_mul_mat(b0, a));
|
|
||||||
// TODO: more checks
|
|
||||||
|
|
||||||
bool is_node = false;
|
|
||||||
|
|
||||||
if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
|
|
||||||
is_node = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
//struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
|
||||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
|
|
||||||
|
|
||||||
result->op = GGML_OP_FLASH_FF;
|
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
|
||||||
result->src[0] = a;
|
|
||||||
result->src[1] = b0;
|
|
||||||
result->src[2] = b1;
|
|
||||||
result->src[3] = c0;
|
|
||||||
result->src[4] = c1;
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ggml_flash_attn_back
|
// ggml_flash_attn_back
|
||||||
|
|
||||||
struct ggml_tensor * ggml_flash_attn_back(
|
struct ggml_tensor * ggml_flash_attn_back(
|
||||||
@ -7080,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back(
|
|||||||
struct ggml_tensor * v,
|
struct ggml_tensor * v,
|
||||||
struct ggml_tensor * d,
|
struct ggml_tensor * d,
|
||||||
bool masked) {
|
bool masked) {
|
||||||
|
GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
|
||||||
|
|
||||||
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
||||||
// TODO: check if vT can be multiplied by (k*qT)
|
// TODO: check if vT can be multiplied by (k*qT)
|
||||||
|
|
||||||
@ -15709,400 +15643,6 @@ static void ggml_compute_forward_argsort(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_flash_attn
|
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_attn_f32(
|
|
||||||
const struct ggml_compute_params * params,
|
|
||||||
const bool masked,
|
|
||||||
struct ggml_tensor * dst) {
|
|
||||||
|
|
||||||
const struct ggml_tensor * q = dst->src[0];
|
|
||||||
const struct ggml_tensor * k = dst->src[1];
|
|
||||||
const struct ggml_tensor * v = dst->src[2];
|
|
||||||
|
|
||||||
int64_t t0 = ggml_perf_time_us();
|
|
||||||
UNUSED(t0);
|
|
||||||
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
||||||
|
|
||||||
const int ith = params->ith;
|
|
||||||
const int nth = params->nth;
|
|
||||||
|
|
||||||
const int64_t D = neq0;
|
|
||||||
const int64_t N = neq1;
|
|
||||||
const int64_t P = nek1 - N;
|
|
||||||
const int64_t M = P + N;
|
|
||||||
|
|
||||||
const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
|
|
||||||
|
|
||||||
GGML_ASSERT(ne0 == D);
|
|
||||||
GGML_ASSERT(ne1 == N);
|
|
||||||
GGML_ASSERT(P >= 0);
|
|
||||||
|
|
||||||
GGML_ASSERT(nbq0 == sizeof(float));
|
|
||||||
GGML_ASSERT(nbk0 == sizeof(float));
|
|
||||||
GGML_ASSERT(nbv0 == sizeof(float));
|
|
||||||
|
|
||||||
GGML_ASSERT(neq0 == D);
|
|
||||||
GGML_ASSERT(nek0 == D);
|
|
||||||
GGML_ASSERT(nev1 == D);
|
|
||||||
|
|
||||||
GGML_ASSERT(neq1 == N);
|
|
||||||
GGML_ASSERT(nek1 == N + P);
|
|
||||||
GGML_ASSERT(nev1 == D);
|
|
||||||
|
|
||||||
// dst cannot be transposed or permuted
|
|
||||||
GGML_ASSERT(nb0 == sizeof(float));
|
|
||||||
GGML_ASSERT(nb0 <= nb1);
|
|
||||||
GGML_ASSERT(nb1 <= nb2);
|
|
||||||
GGML_ASSERT(nb2 <= nb3);
|
|
||||||
|
|
||||||
if (params->type == GGML_TASK_TYPE_INIT) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params->type == GGML_TASK_TYPE_FINALIZE) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// parallelize by q rows using ggml_vec_dot_f32
|
|
||||||
|
|
||||||
// total rows in q
|
|
||||||
const int nr = neq1*neq2*neq3;
|
|
||||||
|
|
||||||
// rows per thread
|
|
||||||
const int dr = (nr + nth - 1)/nth;
|
|
||||||
|
|
||||||
// row range for this thread
|
|
||||||
const int ir0 = dr*ith;
|
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
|
||||||
|
|
||||||
const float scale = 1.0f/sqrtf(D);
|
|
||||||
|
|
||||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
|
||||||
|
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
|
||||||
// q indices
|
|
||||||
const int iq3 = ir/(neq2*neq1);
|
|
||||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
|
||||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
|
||||||
|
|
||||||
float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
|
|
||||||
|
|
||||||
for (int i = M; i < Mup; ++i) {
|
|
||||||
S[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
|
||||||
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
|
||||||
// k indices
|
|
||||||
const int ik3 = iq3;
|
|
||||||
const int ik2 = iq2 % nek2;
|
|
||||||
const int ik1 = ic;
|
|
||||||
|
|
||||||
// S indices
|
|
||||||
const int i1 = ik1;
|
|
||||||
|
|
||||||
ggml_vec_dot_f32(neq0,
|
|
||||||
S + i1, 0,
|
|
||||||
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
|
|
||||||
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// scale
|
|
||||||
ggml_vec_scale_f32(masked_begin, S, scale);
|
|
||||||
|
|
||||||
for (int64_t i = masked_begin; i < M; i++) {
|
|
||||||
S[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
// softmax
|
|
||||||
// exclude known -INF S[..] values from max and loop
|
|
||||||
// dont forget to set their SW values to zero
|
|
||||||
{
|
|
||||||
float max = -INFINITY;
|
|
||||||
ggml_vec_max_f32(masked_begin, &max, S);
|
|
||||||
|
|
||||||
ggml_float sum = 0.0;
|
|
||||||
{
|
|
||||||
#ifdef GGML_SOFT_MAX_ACCELERATE
|
|
||||||
max = -max;
|
|
||||||
vDSP_vsadd(S, 1, &max, S, 1, Mup);
|
|
||||||
vvexpf(S, S, &Mup);
|
|
||||||
ggml_vec_sum_f32(Mup, &sum, S);
|
|
||||||
#else
|
|
||||||
sum = ggml_vec_soft_max_f32(Mup, S, S, max);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(sum > 0.0);
|
|
||||||
|
|
||||||
sum = 1.0/sum;
|
|
||||||
ggml_vec_scale_f32(masked_begin, S, sum);
|
|
||||||
|
|
||||||
#ifndef NDEBUG
|
|
||||||
for (int i = 0; i < masked_begin; ++i) {
|
|
||||||
assert(!isnan(S[i]));
|
|
||||||
assert(!isinf(S[i]));
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t ic = 0; ic < nev1; ++ic) {
|
|
||||||
// dst indices
|
|
||||||
const int i1 = iq1;
|
|
||||||
const int i2 = iq2;
|
|
||||||
const int i3 = iq3;
|
|
||||||
|
|
||||||
// v indices
|
|
||||||
const int iv2 = iq2 % nev2;
|
|
||||||
const int iv3 = iq3;
|
|
||||||
|
|
||||||
ggml_vec_dot_f32(masked_begin,
|
|
||||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
|
|
||||||
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
|
|
||||||
S, 0, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_attn_f16(
|
|
||||||
const struct ggml_compute_params * params,
|
|
||||||
const bool masked,
|
|
||||||
struct ggml_tensor * dst) {
|
|
||||||
|
|
||||||
const struct ggml_tensor * q = dst->src[0];
|
|
||||||
const struct ggml_tensor * k = dst->src[1];
|
|
||||||
const struct ggml_tensor * v = dst->src[2];
|
|
||||||
|
|
||||||
int64_t t0 = ggml_perf_time_us();
|
|
||||||
UNUSED(t0);
|
|
||||||
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
||||||
|
|
||||||
const int ith = params->ith;
|
|
||||||
const int nth = params->nth;
|
|
||||||
|
|
||||||
const int64_t D = neq0;
|
|
||||||
const int64_t N = neq1;
|
|
||||||
const int64_t P = nek1 - N;
|
|
||||||
const int64_t M = P + N;
|
|
||||||
|
|
||||||
const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
|
|
||||||
|
|
||||||
GGML_ASSERT(ne0 == D);
|
|
||||||
GGML_ASSERT(ne1 == N);
|
|
||||||
GGML_ASSERT(P >= 0);
|
|
||||||
|
|
||||||
GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
|
|
||||||
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
|
|
||||||
GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
|
|
||||||
|
|
||||||
GGML_ASSERT(neq0 == D);
|
|
||||||
GGML_ASSERT(nek0 == D);
|
|
||||||
GGML_ASSERT(nev1 == D);
|
|
||||||
|
|
||||||
GGML_ASSERT(neq1 == N);
|
|
||||||
GGML_ASSERT(nek1 == N + P);
|
|
||||||
GGML_ASSERT(nev1 == D);
|
|
||||||
|
|
||||||
// dst cannot be transposed or permuted
|
|
||||||
GGML_ASSERT(nb0 == sizeof(float));
|
|
||||||
GGML_ASSERT(nb0 <= nb1);
|
|
||||||
GGML_ASSERT(nb1 <= nb2);
|
|
||||||
GGML_ASSERT(nb2 <= nb3);
|
|
||||||
|
|
||||||
if (params->type == GGML_TASK_TYPE_INIT) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params->type == GGML_TASK_TYPE_FINALIZE) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// parallelize by q rows using ggml_vec_dot_f32
|
|
||||||
|
|
||||||
// total rows in q
|
|
||||||
const int nr = neq1*neq2*neq3;
|
|
||||||
|
|
||||||
// rows per thread
|
|
||||||
const int dr = (nr + nth - 1)/nth;
|
|
||||||
|
|
||||||
// row range for this thread
|
|
||||||
const int ir0 = dr*ith;
|
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
|
||||||
|
|
||||||
const float scale = 1.0f/sqrtf(D);
|
|
||||||
|
|
||||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
|
||||||
|
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
|
||||||
// q indices
|
|
||||||
const int iq3 = ir/(neq2*neq1);
|
|
||||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
|
||||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
|
||||||
|
|
||||||
float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
|
|
||||||
|
|
||||||
for (int i = M; i < Mup; ++i) {
|
|
||||||
S[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
|
|
||||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
|
||||||
// k indices
|
|
||||||
const int ik3 = iq3;
|
|
||||||
const int ik2 = iq2 % nek2;
|
|
||||||
const int ik1 = ic;
|
|
||||||
|
|
||||||
// S indices
|
|
||||||
const int i1 = ik1;
|
|
||||||
|
|
||||||
ggml_vec_dot_f16(neq0,
|
|
||||||
S + i1, 0,
|
|
||||||
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
|
|
||||||
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
|
||||||
// k indices
|
|
||||||
const int ik3 = iq3;
|
|
||||||
const int ik2 = iq2 % nek2;
|
|
||||||
const int ik1 = ic;
|
|
||||||
|
|
||||||
// S indices
|
|
||||||
const int i1 = ik1;
|
|
||||||
|
|
||||||
ggml_vec_dot_f16_unroll(neq0, nbk1,
|
|
||||||
S + i1,
|
|
||||||
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
|
||||||
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// scale
|
|
||||||
ggml_vec_scale_f32(nek1, S, scale);
|
|
||||||
|
|
||||||
if (masked) {
|
|
||||||
for (int64_t i = P; i < M; i++) {
|
|
||||||
if (i > P + iq1) {
|
|
||||||
S[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// softmax
|
|
||||||
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
|
||||||
// dont forget to set their S values to zero
|
|
||||||
{
|
|
||||||
float max = -INFINITY;
|
|
||||||
ggml_vec_max_f32(M, &max, S);
|
|
||||||
|
|
||||||
ggml_float sum = 0.0;
|
|
||||||
{
|
|
||||||
#ifdef GGML_SOFT_MAX_ACCELERATE
|
|
||||||
max = -max;
|
|
||||||
vDSP_vsadd(S, 1, &max, S, 1, Mup);
|
|
||||||
vvexpf(S, S, &Mup);
|
|
||||||
ggml_vec_sum_f32(Mup, &sum, S);
|
|
||||||
#else
|
|
||||||
sum = ggml_vec_soft_max_f32(Mup, S, S, max);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(sum > 0.0);
|
|
||||||
|
|
||||||
sum = 1.0/sum;
|
|
||||||
ggml_vec_scale_f32(M, S, sum);
|
|
||||||
|
|
||||||
#ifndef NDEBUG
|
|
||||||
for (int i = 0; i < M; ++i) {
|
|
||||||
assert(!isnan(S[i]));
|
|
||||||
assert(!isinf(S[i]));
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
|
|
||||||
|
|
||||||
for (int64_t i = 0; i < M; i++) {
|
|
||||||
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
|
|
||||||
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
|
|
||||||
for (int64_t ic = 0; ic < nev1; ++ic) {
|
|
||||||
// dst indices
|
|
||||||
const int i1 = iq1;
|
|
||||||
const int i2 = iq2;
|
|
||||||
const int i3 = iq3;
|
|
||||||
|
|
||||||
// v indices
|
|
||||||
const int iv2 = iq2 % nev2;
|
|
||||||
const int iv3 = iq3;
|
|
||||||
|
|
||||||
ggml_vec_dot_f16(nev0,
|
|
||||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
|
|
||||||
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
|
|
||||||
S16, 0, 1);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
|
|
||||||
// dst indices
|
|
||||||
const int i1 = iq1;
|
|
||||||
const int i2 = iq2;
|
|
||||||
const int i3 = iq3;
|
|
||||||
|
|
||||||
// v indices
|
|
||||||
const int iv2 = iq2 % nev2;
|
|
||||||
const int iv3 = iq3;
|
|
||||||
|
|
||||||
ggml_vec_dot_f16_unroll(nev0, nbv1,
|
|
||||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
|
||||||
((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
|
||||||
S16);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_attn(
|
|
||||||
const struct ggml_compute_params * params,
|
|
||||||
const bool masked,
|
|
||||||
struct ggml_tensor * dst) {
|
|
||||||
|
|
||||||
const struct ggml_tensor * q = dst->src[0];
|
|
||||||
|
|
||||||
switch (q->type) {
|
|
||||||
case GGML_TYPE_F16:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_flash_attn_f16(params, masked, dst);
|
|
||||||
} break;
|
|
||||||
case GGML_TYPE_F32:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_flash_attn_f32(params, masked, dst);
|
|
||||||
} break;
|
|
||||||
default:
|
|
||||||
{
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ggml_compute_forward_flash_attn_ext
|
// ggml_compute_forward_flash_attn_ext
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
@ -16336,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_flash_ff
|
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_ff_f16(
|
|
||||||
const struct ggml_compute_params * params,
|
|
||||||
struct ggml_tensor * dst) {
|
|
||||||
|
|
||||||
const struct ggml_tensor * a = dst->src[0]; // F16
|
|
||||||
const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
|
|
||||||
const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
|
|
||||||
const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
|
|
||||||
const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
|
|
||||||
|
|
||||||
int64_t t0 = ggml_perf_time_us();
|
|
||||||
UNUSED(t0);
|
|
||||||
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nba, a, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
||||||
|
|
||||||
const int ith = params->ith;
|
|
||||||
const int nth = params->nth;
|
|
||||||
|
|
||||||
const int64_t D = nea0;
|
|
||||||
//const int64_t N = nea1;
|
|
||||||
const int64_t M = neb01;
|
|
||||||
|
|
||||||
GGML_ASSERT(ne0 == nea0);
|
|
||||||
GGML_ASSERT(ne1 == nea1);
|
|
||||||
GGML_ASSERT(ne2 == nea2);
|
|
||||||
|
|
||||||
GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
|
|
||||||
GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
|
|
||||||
GGML_ASSERT(nbb10 == sizeof(float));
|
|
||||||
GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
|
|
||||||
GGML_ASSERT(nbc10 == sizeof(float));
|
|
||||||
|
|
||||||
GGML_ASSERT(neb00 == D);
|
|
||||||
GGML_ASSERT(neb01 == M);
|
|
||||||
GGML_ASSERT(neb10 == M);
|
|
||||||
GGML_ASSERT(neb11 == 1);
|
|
||||||
|
|
||||||
GGML_ASSERT(nec00 == M);
|
|
||||||
GGML_ASSERT(nec01 == D);
|
|
||||||
GGML_ASSERT(nec10 == D);
|
|
||||||
GGML_ASSERT(nec11 == 1);
|
|
||||||
|
|
||||||
// dst cannot be transposed or permuted
|
|
||||||
GGML_ASSERT(nb0 == sizeof(float));
|
|
||||||
GGML_ASSERT(nb0 <= nb1);
|
|
||||||
GGML_ASSERT(nb1 <= nb2);
|
|
||||||
GGML_ASSERT(nb2 <= nb3);
|
|
||||||
|
|
||||||
if (params->type == GGML_TASK_TYPE_INIT) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params->type == GGML_TASK_TYPE_FINALIZE) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// parallelize by a rows using ggml_vec_dot_f32
|
|
||||||
|
|
||||||
// total rows in a
|
|
||||||
const int nr = nea1*nea2*nea3;
|
|
||||||
|
|
||||||
// rows per thread
|
|
||||||
const int dr = (nr + nth - 1)/nth;
|
|
||||||
|
|
||||||
// row range for this thread
|
|
||||||
const int ir0 = dr*ith;
|
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
|
||||||
|
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
|
||||||
// a indices
|
|
||||||
const int ia3 = ir/(nea2*nea1);
|
|
||||||
const int ia2 = (ir - ia3*nea2*nea1)/nea1;
|
|
||||||
const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
|
|
||||||
|
|
||||||
float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
|
|
||||||
|
|
||||||
for (int64_t ic = 0; ic < neb01; ++ic) {
|
|
||||||
// b0 indices
|
|
||||||
const int ib03 = ia3;
|
|
||||||
const int ib02 = ia2;
|
|
||||||
const int ib01 = ic;
|
|
||||||
|
|
||||||
// S indices
|
|
||||||
const int i1 = ib01;
|
|
||||||
|
|
||||||
ggml_vec_dot_f16(nea0,
|
|
||||||
S + i1, 0,
|
|
||||||
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
|
|
||||||
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
|
|
||||||
//ggml_vec_gelu_f32(neb01, S, S);
|
|
||||||
|
|
||||||
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
|
|
||||||
|
|
||||||
for (int64_t i = 0; i < M; i++) {
|
|
||||||
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_vec_gelu_f16(neb01, S16, S16);
|
|
||||||
|
|
||||||
{
|
|
||||||
// dst indices
|
|
||||||
const int i1 = ia1;
|
|
||||||
const int i2 = ia2;
|
|
||||||
const int i3 = ia3;
|
|
||||||
|
|
||||||
for (int64_t ic = 0; ic < nec01; ++ic) {
|
|
||||||
|
|
||||||
ggml_vec_dot_f16(neb01,
|
|
||||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
|
|
||||||
(ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
|
|
||||||
S16, 0, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_vec_add_f32(nec01,
|
|
||||||
(float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
|
|
||||||
(float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
|
|
||||||
(float *) c1->data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_ff(
|
|
||||||
const struct ggml_compute_params * params,
|
|
||||||
struct ggml_tensor * dst) {
|
|
||||||
|
|
||||||
const struct ggml_tensor * b0 = dst->src[1];
|
|
||||||
|
|
||||||
switch (b0->type) {
|
|
||||||
case GGML_TYPE_F16:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_flash_ff_f16(params, dst);
|
|
||||||
} break;
|
|
||||||
case GGML_TYPE_F32:
|
|
||||||
{
|
|
||||||
GGML_ASSERT(false); // TODO
|
|
||||||
} break;
|
|
||||||
default:
|
|
||||||
{
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ggml_compute_forward_flash_attn_back
|
// ggml_compute_forward_flash_attn_back
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_attn_back_f32(
|
static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
@ -18065,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||||||
{
|
{
|
||||||
ggml_compute_forward_leaky_relu(params, tensor);
|
ggml_compute_forward_leaky_relu(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN:
|
|
||||||
{
|
|
||||||
const int32_t t = ggml_get_op_params_i32(tensor, 0);
|
|
||||||
GGML_ASSERT(t == 0 || t == 1);
|
|
||||||
const bool masked = t != 0;
|
|
||||||
ggml_compute_forward_flash_attn(params, masked, tensor);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_FF:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_flash_ff(params, tensor);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
{
|
{
|
||||||
int32_t t = ggml_get_op_params_i32(tensor, 0);
|
int32_t t = ggml_get_op_params_i32(tensor, 0);
|
||||||
@ -19086,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||||||
{
|
{
|
||||||
GGML_ASSERT(false); // TODO: not implemented
|
GGML_ASSERT(false); // TODO: not implemented
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN:
|
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
{
|
{
|
||||||
struct ggml_tensor * flash_grad = NULL;
|
struct ggml_tensor * flash_grad = NULL;
|
||||||
@ -19140,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||||||
zero_table);
|
zero_table);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_FF:
|
|
||||||
{
|
|
||||||
GGML_ASSERT(false); // not supported
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(false); // not supported
|
GGML_ASSERT(false); // not supported
|
||||||
@ -19830,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
|
|||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN:
|
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_FF:
|
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
@ -20235,40 +19595,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
|||||||
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
|
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
|
||||||
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
|
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN:
|
|
||||||
{
|
|
||||||
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
|
|
||||||
|
|
||||||
if (node->src[1]->type == GGML_TYPE_F32) {
|
|
||||||
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
|
|
||||||
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
|
|
||||||
} else if (node->src[1]->type == GGML_TYPE_F16) {
|
|
||||||
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
|
|
||||||
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
|
|
||||||
} else if (node->src[1]->type == GGML_TYPE_BF16) {
|
|
||||||
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
|
|
||||||
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
{
|
{
|
||||||
const int64_t ne00 = node->src[0]->ne[0]; // D
|
const int64_t ne00 = node->src[0]->ne[0]; // D
|
||||||
|
|
||||||
cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
|
cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_FF:
|
|
||||||
{
|
|
||||||
if (node->src[1]->type == GGML_TYPE_F32) {
|
|
||||||
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
|
|
||||||
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
|
|
||||||
} else if (node->src[1]->type == GGML_TYPE_F16) {
|
|
||||||
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
|
|
||||||
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
|
|
||||||
} else if (node->src[1]->type == GGML_TYPE_BF16) {
|
|
||||||
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
|
|
||||||
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
{
|
{
|
||||||
const int64_t D = node->src[0]->ne[0];
|
const int64_t D = node->src[0]->ne[0];
|
||||||
|
18
ggml.h
18
ggml.h
@ -481,9 +481,7 @@ extern "C" {
|
|||||||
GGML_OP_ARGSORT,
|
GGML_OP_ARGSORT,
|
||||||
GGML_OP_LEAKY_RELU,
|
GGML_OP_LEAKY_RELU,
|
||||||
|
|
||||||
GGML_OP_FLASH_ATTN,
|
|
||||||
GGML_OP_FLASH_ATTN_EXT,
|
GGML_OP_FLASH_ATTN_EXT,
|
||||||
GGML_OP_FLASH_FF,
|
|
||||||
GGML_OP_FLASH_ATTN_BACK,
|
GGML_OP_FLASH_ATTN_BACK,
|
||||||
GGML_OP_SSM_CONV,
|
GGML_OP_SSM_CONV,
|
||||||
GGML_OP_SSM_SCAN,
|
GGML_OP_SSM_SCAN,
|
||||||
@ -1761,13 +1759,6 @@ extern "C" {
|
|||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
int k);
|
int k);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_flash_attn(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * q,
|
|
||||||
struct ggml_tensor * k,
|
|
||||||
struct ggml_tensor * v,
|
|
||||||
bool masked);
|
|
||||||
|
|
||||||
#define GGML_KQ_MASK_PAD 32
|
#define GGML_KQ_MASK_PAD 32
|
||||||
|
|
||||||
// q: [n_embd, n_batch, n_head, 1]
|
// q: [n_embd, n_batch, n_head, 1]
|
||||||
@ -1788,6 +1779,7 @@ extern "C" {
|
|||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
enum ggml_prec prec);
|
enum ggml_prec prec);
|
||||||
|
|
||||||
|
// TODO: needs to be adapted to ggml_flash_attn_ext
|
||||||
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * q,
|
struct ggml_tensor * q,
|
||||||
@ -1796,14 +1788,6 @@ extern "C" {
|
|||||||
struct ggml_tensor * d,
|
struct ggml_tensor * d,
|
||||||
bool masked);
|
bool masked);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_flash_ff(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
struct ggml_tensor * b0,
|
|
||||||
struct ggml_tensor * b1,
|
|
||||||
struct ggml_tensor * c0,
|
|
||||||
struct ggml_tensor * c1);
|
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_ssm_conv(
|
GGML_API struct ggml_tensor * ggml_ssm_conv(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * s,
|
struct ggml_tensor * s,
|
||||||
|
@ -1515,90 +1515,50 @@ int main(int argc, const char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// flash_attn f32
|
// flash_attn f32
|
||||||
{
|
// TODO: adapt to ggml_flash_attn_ext() changes
|
||||||
srand(seed);
|
//{
|
||||||
const int nargs = 3;
|
// srand(seed);
|
||||||
|
// const int nargs = 3;
|
||||||
|
|
||||||
int64_t ne2[4];
|
// int64_t ne2[4];
|
||||||
|
|
||||||
get_random_dims(ne2, 4);
|
// get_random_dims(ne2, 4);
|
||||||
int64_t D = ne2[0];
|
// int64_t D = ne2[0];
|
||||||
int64_t N = ne2[1];
|
// int64_t N = ne2[1];
|
||||||
int64_t M = ne2[2] + N;
|
// int64_t M = ne2[2] + N;
|
||||||
int64_t B = ne2[3];
|
// int64_t B = ne2[3];
|
||||||
|
|
||||||
for (int masked = 0; masked <= 1; ++masked) {
|
// for (int masked = 0; masked <= 1; ++masked) {
|
||||||
for (int ndims = 2; ndims <= 4; ++ndims) {
|
// for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||||
int max_nrep = (ndims >= 3) ? 2 : 1;
|
// int max_nrep = (ndims >= 3) ? 2 : 1;
|
||||||
for (int nrep = 1; nrep < max_nrep; ++nrep) {
|
// for (int nrep = 1; nrep < max_nrep; ++nrep) {
|
||||||
int64_t neq[4] = { D, N, B*nrep, ne[3] };
|
// int64_t neq[4] = { D, N, B*nrep, ne[3] };
|
||||||
int64_t nek[4] = { D, M, B, ne[3] };
|
// int64_t nek[4] = { D, M, B, ne[3] };
|
||||||
int64_t nev[4] = { M, D, B, ne[3] };
|
// int64_t nev[4] = { M, D, B, ne[3] };
|
||||||
if (ndims == 2) {
|
// if (ndims == 2) {
|
||||||
neq[2] = 1; neq[3] = 1;
|
// neq[2] = 1; neq[3] = 1;
|
||||||
nek[2] = 1; nek[3] = 1;
|
// nek[2] = 1; nek[3] = 1;
|
||||||
nev[2] = 1; nev[3] = 1;
|
// nev[2] = 1; nev[3] = 1;
|
||||||
} else if (ndims == 3) {
|
// } else if (ndims == 3) {
|
||||||
neq[3] = 1;
|
// neq[3] = 1;
|
||||||
nek[3] = 1;
|
// nek[3] = 1;
|
||||||
nev[3] = 1;
|
// nev[3] = 1;
|
||||||
}
|
// }
|
||||||
x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
// x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
||||||
x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
// x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
||||||
x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
// x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
// ggml_set_param(ctx0, x[0]);
|
||||||
ggml_set_param(ctx0, x[1]);
|
// ggml_set_param(ctx0, x[1]);
|
||||||
ggml_set_param(ctx0, x[2]);
|
// ggml_set_param(ctx0, x[2]);
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
// struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||||
|
|
||||||
check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
// check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
//}
|
||||||
|
|
||||||
// flash_attn f16, not yet fully implemented
|
|
||||||
if(0)
|
|
||||||
{
|
|
||||||
srand(seed);
|
|
||||||
const int nargs = 3;
|
|
||||||
|
|
||||||
int64_t ne2[4];
|
|
||||||
|
|
||||||
get_random_dims(ne2, 4);
|
|
||||||
int64_t D = ne2[0];
|
|
||||||
int64_t N = ne2[1];
|
|
||||||
int64_t M = ne2[2] + N;
|
|
||||||
int64_t B = ne2[3];
|
|
||||||
|
|
||||||
for (int masked = 0; masked <= 1; ++masked) {
|
|
||||||
for (int ndims = 2; ndims <= 4; ++ndims) {
|
|
||||||
int64_t neq[4] = { D, N, B, ne[3] };
|
|
||||||
int64_t nek[4] = { D, M, B, ne[3] };
|
|
||||||
int64_t nev[4] = { M, D, B, ne[3] };
|
|
||||||
if (ndims == 2) {
|
|
||||||
neq[2] = 1; neq[3] = 1;
|
|
||||||
nek[2] = 1; nek[3] = 1;
|
|
||||||
nev[2] = 1; nev[3] = 1;
|
|
||||||
} else if (ndims == 3) {
|
|
||||||
neq[3] = 1;
|
|
||||||
nek[3] = 1;
|
|
||||||
nev[3] = 1;
|
|
||||||
}
|
|
||||||
x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
|
||||||
x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
|
||||||
x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
|
||||||
ggml_set_param(ctx0, x[1]);
|
|
||||||
ggml_set_param(ctx0, x[2]);
|
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
|
||||||
|
|
||||||
check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ggml_free(ctx0);
|
ggml_free(ctx0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user