From 5ca49cbecda27ce0a7266658fc3b640bff3ed386 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sun, 19 May 2024 16:46:13 +0200
Subject: [PATCH] ggml: implement quantized KV cache for FA (#7372)

---
 ggml.c | 115 ++++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 72 insertions(+), 43 deletions(-)

diff --git a/ggml.c b/ggml.c
index a04c74ddd..3a104c486 100644
--- a/ggml.c
+++ b/ggml.c
@@ -15882,9 +15882,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne2 == N);
 
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
@@ -15938,6 +15939,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    enum ggml_type    const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -15945,17 +15951,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        const uint32_t h = iq2; // head
+        const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 
-        float S = 0.0f;
-        float M = -INFINITY;
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
 
-        float       * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
-        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
-        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
 
-        memset(V16, 0, D*sizeof(ggml_fp16_t));
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
@@ -15967,6 +15978,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,51 +15990,66 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 continue;
             }
 
-            float s;
+            float s; // KQ value
 
-            // convert Q to F16 in V32
-            {
-                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-                for (int64_t d = 0; d < D; ++d) {
-                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-                }
-            }
-
-            ggml_vec_dot_f16(D,
-                    &s, 0,
-                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    Q16, 0, 1);
-
-            s = s*scale + mv;
+            s = s*scale + mv; // scale KQ value and apply mask
 
             const float Mold = M;
 
-            float ms = 1.0f;
-            float vs = 1.0f;
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
 
-            if (s > M) {
-                M = s;
-                ms = expf(Mold - M);
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-                // V = V*expf(Mold - M)
-                ggml_vec_scale_f16(D, V16, ms);
+            if (v->type == GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
-                vs = expf(s - M);
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                v_to_float(v_data, V32, D);
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
             }
 
-            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+            S = S*ms + vs; // scale and increment sum with partial sum
+        }
 
-            // V += v*expf(s - M)
-            ggml_vec_mad_f16(D, V16, v16, vs);
-
-            S = S*ms + vs;
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
         }
 
         // V /= S
-        for (int64_t d = 0; d < D; ++d) {
-            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
-        }
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
 
         // dst indices
         const int i1 = iq1;
@@ -16031,7 +16060,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
         // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }
 
@@ -19972,7 +20001,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D
 
-                    cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_FF:
                 {
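For reference, the following is a small standalone sketch, not part of the patch, of the online-softmax accumulation that the patched ggml_compute_forward_flash_attn_ext_f16 performs for each query row, written in plain FP32 C for a single head of size D. The function and parameter names (online_softmax_weighted_sum, attn_scores, values, out) are illustrative only; the real code additionally converts Q to the K type's vec_dot_type, dispatches the KQ dot product and the V-row conversion through type_traits, and keeps an FP16 accumulator when V is F16.

#include <math.h>
#include <stddef.h>
#include <string.h>

// out[d] = sum_i softmax(attn_scores)[i] * values[i*D + d], computed in a
// single pass with a running maximum M and running sum S (online softmax),
// following the same update order as the patch (ref in the diff:
// https://arxiv.org/pdf/2112.05682.pdf).
static void online_softmax_weighted_sum(
        const float * attn_scores, // [n_kv] pre-softmax scores (already scaled and masked)
        const float * values,      // [n_kv * D] value rows
        float       * out,         // [D] output accumulator, overwritten
        size_t n_kv, size_t D) {
    float S = 0.0f;      // running sum of expf(s - M)
    float M = -INFINITY; // running maximum score

    memset(out, 0, D*sizeof(float));

    for (size_t ic = 0; ic < n_kv; ++ic) {
        const float s = attn_scores[ic];

        const float Mold = M;

        float ms = 1.0f; // rescales the accumulator and S when a new maximum appears
        float vs = 1.0f; // expf(s - M) for the current position

        if (s > M) {
            // s is the new maximum, so vs == expf(s - s) == 1.0f
            M  = s;
            ms = expf(Mold - M);

            for (size_t d = 0; d < D; ++d) {
                out[d] *= ms; // V = V*expf(Mold - M)
            }
        } else {
            vs = expf(s - M);
        }

        for (size_t d = 0; d < D; ++d) {
            out[d] += values[ic*D + d]*vs; // V += v*expf(s - M)
        }

        S = S*ms + vs; // scale and increment the softmax denominator
    }

    // V /= S
    const float S_inv = 1.0f/S;
    for (size_t d = 0; d < D; ++d) {
        out[d] *= S_inv;
    }
}

Because only the running maximum and the running sum have to be kept, this update never materializes the full KQ row, which is what lets the patch stream over K/V rows of any type (dequantizing one V row at a time into the per-thread V32 buffer) and is also why the per-thread work buffer in ggml_graph_plan grows from 2*D to 3*D floats.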