diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 992426c1b..22425730f 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -643,7 +643,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd_head, n_head_kv, n_batch); struct ggml_tensor * t16; if (enable_flash_attn) { - t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch); + GGML_ASSERT(false && "TODO: ggml_flash_attn_ext() not yet supported"); + //t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch); } else { struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch); struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 45bdfa8f5..e2f85c682 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -341,7 +341,8 @@ static struct ggml_tensor * llama_build_train_graphs( struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch); struct ggml_tensor * t16; if (enable_flash_attn) { - t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch); + GGML_ASSERT(false && "TODO: ggml_flash_attn_ext() not yet supported"); + //t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch); } else { struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch); struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch); diff --git a/ggml.c b/ggml.c index 673c47748..9e72b7a76 100644 --- a/ggml.c +++ b/ggml.c @@ -2670,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARGSORT", "LEAKY_RELU", - "FLASH_ATTN", "FLASH_ATTN_EXT", - "FLASH_FF", "FLASH_ATTN_BACK", "SSM_CONV", "SSM_SCAN", @@ -2698,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2760,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "argsort(x)", "leaky_relu(x)", - "flash_attn(x)", "flash_attn_ext(x)", - "flash_ff(x)", "flash_attn_back(x)", "ssm_conv(x)", "ssm_scan(x)", @@ -2788,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6948,38 +6944,6 @@ struct ggml_tensor * ggml_top_k( return result; } -// ggml_flash_attn - -struct ggml_tensor * ggml_flash_attn( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - bool masked) { - GGML_ASSERT(ggml_can_mul_mat(k, q)); - // TODO: check if vT can be multiplied by (k*qT) - - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - is_node = true; - } - - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); - - int32_t t = masked ? 1 : 0; - ggml_set_op_params(result, &t, sizeof(t)); - - result->op = GGML_OP_FLASH_ATTN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = q; - result->src[1] = k; - result->src[2] = v; - - return result; -} - // ggml_flash_attn_ext struct ggml_tensor * ggml_flash_attn_ext( @@ -7039,38 +7003,6 @@ void ggml_flash_attn_ext_set_prec( ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second } -// ggml_flash_ff - -struct ggml_tensor * ggml_flash_ff( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b0, - struct ggml_tensor * b1, - struct ggml_tensor * c0, - struct ggml_tensor * c1) { - GGML_ASSERT(ggml_can_mul_mat(b0, a)); - // TODO: more checks - - bool is_node = false; - - if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { - is_node = true; - } - - //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne); - - result->op = GGML_OP_FLASH_FF; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b0; - result->src[2] = b1; - result->src[3] = c0; - result->src[4] = c1; - - return result; -} - // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back( @@ -7080,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * v, struct ggml_tensor * d, bool masked) { + GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes"); + GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) @@ -15709,400 +15643,6 @@ static void ggml_compute_forward_argsort( } } -// ggml_compute_forward_flash_attn - -static void ggml_compute_forward_flash_attn_f32( - const struct ggml_compute_params * params, - const bool masked, - struct ggml_tensor * dst) { - - const struct ggml_tensor * q = dst->src[0]; - const struct ggml_tensor * k = dst->src[1]; - const struct ggml_tensor * v = dst->src[2]; - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - - GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(float)); - GGML_ASSERT(nbk0 == sizeof(float)); - GGML_ASSERT(nbv0 == sizeof(float)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - const int64_t masked_begin = masked ? (P + iq1 + 1) : M; - for (int64_t ic = 0; ic < masked_begin; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f32(neq0, - S + i1, 0, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1); - } - - // scale - ggml_vec_scale_f32(masked_begin, S, scale); - - for (int64_t i = masked_begin; i < M; i++) { - S[i] = -INFINITY; - } - - // softmax - // exclude known -INF S[..] values from max and loop - // dont forget to set their SW values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(masked_begin, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - sum = ggml_vec_soft_max_f32(Mup, S, S, max); -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(masked_begin, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < masked_begin; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; - - ggml_vec_dot_f32(masked_begin, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0, - (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0, - S, 0, 1); - } - } -} - -static void ggml_compute_forward_flash_attn_f16( - const struct ggml_compute_params * params, - const bool masked, - struct ggml_tensor * dst) { - - const struct ggml_tensor * q = dst->src[0]; - const struct ggml_tensor * k = dst->src[1]; - const struct ggml_tensor * v = dst->src[2]; - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - - GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f16(neq0, - S + i1, 0, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1); - } - } else { - for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } - - // scale - ggml_vec_scale_f32(nek1, S, scale); - - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } - } - - // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. - // dont forget to set their S values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - sum = ggml_vec_soft_max_f32(Mup, S, S, max); -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); - } - - // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). - if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; - - ggml_vec_dot_f16(nev0, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0, - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0, - S16, 0, 1); - } - } else { - for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; - - ggml_vec_dot_f16_unroll(nev0, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } - } -} - -static void ggml_compute_forward_flash_attn( - const struct ggml_compute_params * params, - const bool masked, - struct ggml_tensor * dst) { - - const struct ggml_tensor * q = dst->src[0]; - - switch (q->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_flash_attn_f16(params, masked, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_flash_attn_f32(params, masked, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( @@ -16336,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext( } } -// ggml_compute_forward_flash_ff - -static void ggml_compute_forward_flash_ff_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * a = dst->src[0]; // F16 - const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w - const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b - const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w - const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, nea, a, ne) - GGML_TENSOR_LOCALS(size_t, nba, a, nb) - GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne) - GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb) - GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne) - GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb) - GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne) - GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb) - GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne) - GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = nea0; - //const int64_t N = nea1; - const int64_t M = neb01; - - GGML_ASSERT(ne0 == nea0); - GGML_ASSERT(ne1 == nea1); - GGML_ASSERT(ne2 == nea2); - - GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbb10 == sizeof(float)); - GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbc10 == sizeof(float)); - - GGML_ASSERT(neb00 == D); - GGML_ASSERT(neb01 == M); - GGML_ASSERT(neb10 == M); - GGML_ASSERT(neb11 == 1); - - GGML_ASSERT(nec00 == M); - GGML_ASSERT(nec01 == D); - GGML_ASSERT(nec10 == D); - GGML_ASSERT(nec11 == 1); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - // parallelize by a rows using ggml_vec_dot_f32 - - // total rows in a - const int nr = nea1*nea2*nea3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // a indices - const int ia3 = ir/(nea2*nea1); - const int ia2 = (ir - ia3*nea2*nea1)/nea1; - const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); - - float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); - - for (int64_t ic = 0; ic < neb01; ++ic) { - // b0 indices - const int ib03 = ia3; - const int ib02 = ia2; - const int ib01 = ic; - - // S indices - const int i1 = ib01; - - ggml_vec_dot_f16(nea0, - S + i1, 0, - (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0, - (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1); - } - - ggml_vec_add_f32(neb01, S, S, (float *) b1->data); - //ggml_vec_gelu_f32(neb01, S, S); - - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); - } - - ggml_vec_gelu_f16(neb01, S16, S16); - - { - // dst indices - const int i1 = ia1; - const int i2 = ia2; - const int i3 = ia3; - - for (int64_t ic = 0; ic < nec01; ++ic) { - - ggml_vec_dot_f16(neb01, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0, - (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0, - S16, 0, 1); - } - - ggml_vec_add_f32(nec01, - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) c1->data); - } - } -} - -static void ggml_compute_forward_flash_ff( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * b0 = dst->src[1]; - - switch (b0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_flash_ff_f16(params, dst); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(false); // TODO - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - // ggml_compute_forward_flash_attn_back static void ggml_compute_forward_flash_attn_back_f32( @@ -18065,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_leaky_relu(params, tensor); } break; - case GGML_OP_FLASH_ATTN: - { - const int32_t t = ggml_get_op_params_i32(tensor, 0); - GGML_ASSERT(t == 0 || t == 1); - const bool masked = t != 0; - ggml_compute_forward_flash_attn(params, masked, tensor); - } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); } break; - case GGML_OP_FLASH_FF: - { - ggml_compute_forward_flash_ff(params, tensor); - } break; case GGML_OP_FLASH_ATTN_BACK: { int32_t t = ggml_get_op_params_i32(tensor, 0); @@ -19086,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; - case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; @@ -19140,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; - case GGML_OP_FLASH_FF: - { - GGML_ASSERT(false); // not supported - } break; case GGML_OP_FLASH_ATTN_BACK: { GGML_ASSERT(false); // not supported @@ -19830,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ { n_tasks = n_threads; } break; - case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; - case GGML_OP_FLASH_FF: - { - n_tasks = n_threads; - } break; case GGML_OP_FLASH_ATTN_BACK: { n_tasks = n_threads; @@ -20235,40 +19595,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; - case GGML_OP_FLASH_ATTN: - { - const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); - - if (node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_F16) { - cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_BF16) { - cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 - } - } break; case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne00 = node->src[0]->ne[0]; // D cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread } break; - case GGML_OP_FLASH_FF: - { - if (node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_F16) { - cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_BF16) { - cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } - } break; case GGML_OP_FLASH_ATTN_BACK: { const int64_t D = node->src[0]->ne[0]; diff --git a/ggml.h b/ggml.h index 08835042c..be81e0c52 100644 --- a/ggml.h +++ b/ggml.h @@ -481,9 +481,7 @@ extern "C" { GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, - GGML_OP_FLASH_ATTN, GGML_OP_FLASH_ATTN_EXT, - GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_SSM_CONV, GGML_OP_SSM_SCAN, @@ -1761,13 +1759,6 @@ extern "C" { struct ggml_tensor * a, int k); - GGML_API struct ggml_tensor * ggml_flash_attn( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - bool masked); - #define GGML_KQ_MASK_PAD 32 // q: [n_embd, n_batch, n_head, 1] @@ -1788,6 +1779,7 @@ extern "C" { struct ggml_tensor * a, enum ggml_prec prec); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, @@ -1796,14 +1788,6 @@ extern "C" { struct ggml_tensor * d, bool masked); - GGML_API struct ggml_tensor * ggml_flash_ff( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b0, - struct ggml_tensor * b1, - struct ggml_tensor * c0, - struct ggml_tensor * c1); - GGML_API struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, struct ggml_tensor * s, diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 8ff76c891..21ca43be3 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -1515,90 +1515,50 @@ int main(int argc, const char ** argv) { } // flash_attn f32 - { - srand(seed); - const int nargs = 3; + // TODO: adapt to ggml_flash_attn_ext() changes + //{ + // srand(seed); + // const int nargs = 3; - int64_t ne2[4]; + // int64_t ne2[4]; - get_random_dims(ne2, 4); - int64_t D = ne2[0]; - int64_t N = ne2[1]; - int64_t M = ne2[2] + N; - int64_t B = ne2[3]; + // get_random_dims(ne2, 4); + // int64_t D = ne2[0]; + // int64_t N = ne2[1]; + // int64_t M = ne2[2] + N; + // int64_t B = ne2[3]; - for (int masked = 0; masked <= 1; ++masked) { - for (int ndims = 2; ndims <= 4; ++ndims) { - int max_nrep = (ndims >= 3) ? 2 : 1; - for (int nrep = 1; nrep < max_nrep; ++nrep) { - int64_t neq[4] = { D, N, B*nrep, ne[3] }; - int64_t nek[4] = { D, M, B, ne[3] }; - int64_t nev[4] = { M, D, B, ne[3] }; - if (ndims == 2) { - neq[2] = 1; neq[3] = 1; - nek[2] = 1; nek[3] = 1; - nev[2] = 1; nev[3] = 1; - } else if (ndims == 3) { - neq[3] = 1; - nek[3] = 1; - nev[3] = 1; - } - x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f); - x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f); - x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f); - ggml_set_param(ctx0, x[0]); - ggml_set_param(ctx0, x[1]); - ggml_set_param(ctx0, x[2]); + // for (int masked = 0; masked <= 1; ++masked) { + // for (int ndims = 2; ndims <= 4; ++ndims) { + // int max_nrep = (ndims >= 3) ? 2 : 1; + // for (int nrep = 1; nrep < max_nrep; ++nrep) { + // int64_t neq[4] = { D, N, B*nrep, ne[3] }; + // int64_t nek[4] = { D, M, B, ne[3] }; + // int64_t nev[4] = { M, D, B, ne[3] }; + // if (ndims == 2) { + // neq[2] = 1; neq[3] = 1; + // nek[2] = 1; nek[3] = 1; + // nev[2] = 1; nev[3] = 1; + // } else if (ndims == 3) { + // neq[3] = 1; + // nek[3] = 1; + // nev[3] = 1; + // } + // x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f); + // x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f); + // x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f); + // ggml_set_param(ctx0, x[0]); + // ggml_set_param(ctx0, x[1]); + // ggml_set_param(ctx0, x[2]); - struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); + // struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); - } - } - } - } + // check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); + // } + // } + // } + //} - // flash_attn f16, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 3; - - int64_t ne2[4]; - - get_random_dims(ne2, 4); - int64_t D = ne2[0]; - int64_t N = ne2[1]; - int64_t M = ne2[2] + N; - int64_t B = ne2[3]; - - for (int masked = 0; masked <= 1; ++masked) { - for (int ndims = 2; ndims <= 4; ++ndims) { - int64_t neq[4] = { D, N, B, ne[3] }; - int64_t nek[4] = { D, M, B, ne[3] }; - int64_t nev[4] = { M, D, B, ne[3] }; - if (ndims == 2) { - neq[2] = 1; neq[3] = 1; - nek[2] = 1; nek[3] = 1; - nev[2] = 1; nev[3] = 1; - } else if (ndims == 3) { - neq[3] = 1; - nek[3] = 1; - nev[3] = 1; - } - x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f); - x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f); - x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f); - ggml_set_param(ctx0, x[0]); - ggml_set_param(ctx0, x[1]); - ggml_set_param(ctx0, x[2]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - - check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); - } - } - } ggml_free(ctx0); }