mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 14:28:58 +01:00
ggml : online attention (CPU)
This commit is contained in:
parent
c3cdfffa88
commit
a9681febd6
@ -2207,9 +2207,15 @@ static bool ggml_metal_graph_compute(
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:21];
|
||||
|
||||
const int nwarps = 4;
|
||||
|
||||
// each warp needs n_embd_head elements
|
||||
GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)];
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
|
@ -1982,6 +1982,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
constant float & scale,
|
||||
threadgroup float * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
257
ggml.c
257
ggml.c
@ -817,7 +817,7 @@ do { \
|
||||
|
||||
#if defined(__F16C__)
|
||||
// the _mm256_cvt intrinsics require F16C
|
||||
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
|
||||
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
|
||||
#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
|
||||
#else
|
||||
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
|
||||
@ -1323,6 +1323,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
|
||||
#if defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// xs and vs are byte strides of x and v
|
||||
inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
|
||||
|
||||
@ -1407,6 +1438,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
|
||||
#if defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); }
|
||||
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
|
||||
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
|
||||
@ -5704,8 +5764,9 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
//struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
|
||||
// permute(0, 2, 1, 3)
|
||||
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
|
||||
|
||||
float params[] = { scale };
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
@ -13281,12 +13342,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const int64_t D = neq0;
|
||||
const int64_t N = neq1;
|
||||
const int64_t P = nek1 - N;
|
||||
const int64_t M = P + N;
|
||||
|
||||
const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
|
||||
|
||||
GGML_ASSERT(ne0 == D);
|
||||
GGML_ASSERT(ne1 == N);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
GGML_ASSERT(P >= 0);
|
||||
|
||||
GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
|
||||
@ -13295,11 +13353,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
GGML_ASSERT(neq0 == D);
|
||||
GGML_ASSERT(nek0 == D);
|
||||
GGML_ASSERT(nev1 == D);
|
||||
GGML_ASSERT(nev0 == D);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
GGML_ASSERT(nek1 == N + P);
|
||||
GGML_ASSERT(nev1 == D);
|
||||
GGML_ASSERT(nev0 == D);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
@ -13339,151 +13397,87 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
||||
|
||||
// loop over n_batch and n_head
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
const int iq3 = ir/(neq2*neq1);
|
||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
||||
|
||||
float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
|
||||
float S = 0.0f;
|
||||
float M = -INFINITY;
|
||||
|
||||
for (int i = M; i < Mup; ++i) {
|
||||
S[i] = -INFINITY;
|
||||
}
|
||||
float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
|
||||
ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
|
||||
|
||||
if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
|
||||
memset(V16, 0, D*sizeof(ggml_fp16_t));
|
||||
|
||||
const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL;
|
||||
|
||||
// k indices
|
||||
const int ik3 = iq3 / rk3;
|
||||
const int ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const int iv2 = iq2 / rv2;
|
||||
const int iv3 = iq3 / rv3;
|
||||
|
||||
// online softmax / attention
|
||||
// loop over n_kv and n_head_kv
|
||||
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
// k indices
|
||||
const int ik3 = iq3 / rk3;
|
||||
const int ik2 = iq2 / rk2;
|
||||
const int ik1 = ic;
|
||||
const float mv = mp ? mp[ic] : 0.0f;
|
||||
if (mv == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// S indices
|
||||
const int i1 = ik1;
|
||||
float s;
|
||||
|
||||
ggml_vec_dot_f16(neq0,
|
||||
S + i1,
|
||||
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||
ggml_vec_dot_f16(D,
|
||||
&s,
|
||||
(ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
||||
}
|
||||
|
||||
s = s*scale + mv;
|
||||
|
||||
const float Mold = M;
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M) {
|
||||
M = s;
|
||||
ms = expf(Mold - M);
|
||||
|
||||
// V = V*expf(Mold - M)
|
||||
ggml_vec_scale_f16(D, V16, ms);
|
||||
} else {
|
||||
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
||||
// k indices
|
||||
const int ik3 = iq3 / rk3;
|
||||
const int ik2 = iq2 / rk2;
|
||||
const int ik1 = ic;
|
||||
|
||||
// S indices
|
||||
const int i1 = ik1;
|
||||
|
||||
ggml_vec_dot_f16_unroll(neq0, nbk1,
|
||||
S + i1,
|
||||
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
||||
}
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
// scale
|
||||
ggml_vec_scale_f32(nek1, S, scale);
|
||||
const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
|
||||
if (mask) {
|
||||
const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]);
|
||||
ggml_vec_acc_f32(M, S, mp);
|
||||
// V += v*expf(s - M)
|
||||
ggml_vec_mad_f16(D, V16, v16, vs);
|
||||
|
||||
S = S*ms + vs;
|
||||
}
|
||||
|
||||
// softmax
|
||||
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
||||
// dont forget to set their S values to zero
|
||||
{
|
||||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(M, &max, S);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
{
|
||||
#ifdef GGML_SOFT_MAX_ACCELERATE
|
||||
max = -max;
|
||||
vDSP_vsadd(S, 1, &max, S, 1, Mup);
|
||||
vvexpf(S, S, &Mup);
|
||||
ggml_vec_sum_f32(Mup, &sum, S);
|
||||
#else
|
||||
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
|
||||
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
||||
|
||||
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
||||
float * SS = S + i;
|
||||
|
||||
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
||||
if (SS[j] == -INFINITY) {
|
||||
SS[j] = 0.0f;
|
||||
} else {
|
||||
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
|
||||
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
||||
const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
|
||||
sump[j] += (ggml_float)val;
|
||||
SS[j] = val;
|
||||
}
|
||||
}
|
||||
// V /= S
|
||||
for (int64_t d = 0; d < D; ++d) {
|
||||
V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
||||
sum += sump[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
assert(sum > 0.0);
|
||||
|
||||
sum = 1.0/sum;
|
||||
ggml_vec_scale_f32(M, S, sum);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < M; ++i) {
|
||||
assert(!isnan(S[i]));
|
||||
assert(!isinf(S[i]));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
|
||||
|
||||
for (int64_t i = 0; i < M; i++) {
|
||||
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
||||
}
|
||||
|
||||
// todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
|
||||
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
|
||||
for (int64_t ic = 0; ic < nev1; ++ic) {
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// v indices
|
||||
const int iv2 = iq2 / rv2;
|
||||
const int iv3 = iq3 / rv3;
|
||||
// original
|
||||
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
|
||||
|
||||
ggml_vec_dot_f16(nev0,
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||
S16);
|
||||
}
|
||||
} else {
|
||||
for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// v indices
|
||||
const int iv2 = iq2 / rv2;
|
||||
const int iv3 = iq3 / rv3;
|
||||
|
||||
ggml_vec_dot_f16_unroll(nev0, nbv1,
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||
((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||
S16);
|
||||
}
|
||||
}
|
||||
// permute(0, 2, 1, 3)
|
||||
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17069,7 +17063,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
|
||||
|
||||
@ -17081,6 +17074,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const int64_t ne00 = node->src[0]->ne[0]; // D
|
||||
|
||||
cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
|
||||
} break;
|
||||
case GGML_OP_FLASH_FF:
|
||||
{
|
||||
if (node->src[1]->type == GGML_TYPE_F32) {
|
||||
|
5
ggml.h
5
ggml.h
@ -1620,6 +1620,11 @@ extern "C" {
|
||||
struct ggml_tensor * v,
|
||||
bool masked);
|
||||
|
||||
// q: [n_embd, n_batch, n_head, 1]
|
||||
// k: [n_embd, n_kv, n_head_kv, 1]
|
||||
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
|
||||
// mask: [n_kv, n_batch, 1, 1]
|
||||
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
|
||||
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
|
68
llama.cpp
68
llama.cpp
@ -95,6 +95,8 @@
|
||||
#define LLAMA_MAX_NODES 8192
|
||||
#define LLAMA_MAX_EXPERTS 8
|
||||
|
||||
#define LLAMA_FLASH_ATTN
|
||||
|
||||
//
|
||||
// logging
|
||||
//
|
||||
@ -4167,23 +4169,34 @@ static void llm_build_kv_store(
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
// compute the transposed [n_tokens, n_embd] V matrix
|
||||
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
|
||||
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
|
||||
cb(v_cur_t, "v_cur_t", il);
|
||||
|
||||
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
|
||||
(ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
|
||||
cb(k_cache_view, "k_cache_view", il);
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
|
||||
|
||||
#if defined(LLAMA_FLASH_ATTN)
|
||||
// NOTE: the V cache is not transposed when using FLASH attention !!
|
||||
struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
|
||||
(ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head);
|
||||
cb(v_cache_view, "v_cache_view", il);
|
||||
|
||||
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
|
||||
|
||||
GGML_UNUSED(n_ctx);
|
||||
#else
|
||||
// compute the transposed [n_tokens, n_embd] V matrix
|
||||
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
|
||||
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
|
||||
cb(v_cur_t, "v_cur_t", il);
|
||||
|
||||
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
|
||||
( n_ctx)*ggml_element_size(kv.v_l[il]),
|
||||
(kv_head)*ggml_element_size(kv.v_l[il]));
|
||||
cb(v_cache_view, "v_cache_view", il);
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
|
||||
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
|
||||
#endif
|
||||
}
|
||||
|
||||
static struct ggml_tensor * llm_build_norm(
|
||||
@ -4343,27 +4356,27 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
0);
|
||||
cb(k, "k", il);
|
||||
|
||||
// split cached v into n_head heads
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
#if defined(LLAMA_FLASH_ATTN)
|
||||
// split cached v into n_head heads (not transposed)
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx, kv.v_l[il],
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(kv.v_l[il])*n_ctx,
|
||||
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
|
||||
n_embd_head_v, n_kv, n_head_kv,
|
||||
ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
|
||||
ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
|
||||
// TODO: determine if we can use flash attention
|
||||
const bool supports_flash_attn = true;
|
||||
|
||||
struct ggml_tensor * kqv;
|
||||
|
||||
if (supports_flash_attn) {
|
||||
cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
|
||||
//printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
|
||||
//printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
|
||||
//printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
|
||||
//printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]);
|
||||
kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
|
||||
} else {
|
||||
//printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]);
|
||||
|
||||
cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
|
||||
#else
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
@ -4396,15 +4409,24 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
}
|
||||
|
||||
kqv = ggml_mul_mat(ctx, v, kq);
|
||||
// split cached v into n_head heads (transposed)
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx, kv.v_l[il],
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(kv.v_l[il])*n_ctx,
|
||||
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
|
||||
cb(kqv, "kqv", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
|
||||
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
#endif
|
||||
|
||||
cur = ggml_mul_mat(ctx, wo, cur);
|
||||
if (wo_b) {
|
||||
|
@ -1390,21 +1390,21 @@ struct test_flash_attn_ext : public test_case {
|
||||
const int64_t hs; // head size
|
||||
const int64_t nh; // num heads
|
||||
const int64_t kv; // kv size
|
||||
const int64_t nt; // tokens
|
||||
const int64_t nb; // batch size
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(typeq, hs, nh, kv, nt);
|
||||
return VARS_TO_STR5(typeq, hs, nh, kv, nb);
|
||||
}
|
||||
|
||||
test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
|
||||
int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8)
|
||||
: typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {}
|
||||
int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
|
||||
: typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1);
|
||||
ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1);
|
||||
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1);
|
||||
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1);
|
||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
||||
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
|
||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
|
||||
return out;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user