mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00
ggml : add ggml_flash_attn_ext API
This commit is contained in:
parent
ad19812cda
commit
a1c004ef2e
50
ggml-metal.m
50
ggml-metal.m
@ -147,6 +147,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
||||
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||
@ -511,6 +512,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||
@ -665,6 +667,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
@ -2161,6 +2164,53 @@ static bool ggml_metal_graph_compute(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
|
||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
||||
|
||||
size_t offs_src2 = 0;
|
||||
size_t offs_src3 = 0;
|
||||
|
||||
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil;
|
||||
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil;
|
||||
|
||||
float scale;
|
||||
memcpy(&scale, dst->op_params, sizeof(float));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline;
|
||||
|
||||
// TODO: extend if necessary
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:21];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
|
@ -1959,6 +1959,35 @@ kernel void kernel_leaky_relu_f32(
|
||||
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
||||
}
|
||||
|
||||
kernel void kernel_flash_attn_ext_f16(
|
||||
device const half * q,
|
||||
device const half * k,
|
||||
device const half * v,
|
||||
device const half * mask,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
constant float & scale,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
// TODO: implement
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
device const half * src0,
|
||||
device half * dst,
|
||||
|
298
ggml.c
298
ggml.c
@ -1650,6 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"LEAKY_RELU",
|
||||
|
||||
"FLASH_ATTN",
|
||||
"FLASH_ATTN_EXT",
|
||||
"FLASH_FF",
|
||||
"FLASH_ATTN_BACK",
|
||||
"WIN_PART",
|
||||
@ -1674,7 +1675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
|
||||
static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@ -1736,6 +1737,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"leaky_relu(x)",
|
||||
|
||||
"flash_attn(x)",
|
||||
"flash_attn_ext(x)",
|
||||
"flash_ff(x)",
|
||||
"flash_attn_back(x)",
|
||||
"win_part(x)",
|
||||
@ -1760,7 +1762,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"cross_entropy_loss_back(x,y)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
|
||||
static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@ -5678,6 +5680,46 @@ struct ggml_tensor * ggml_flash_attn(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_flash_attn_ext
|
||||
|
||||
struct ggml_tensor * ggml_flash_attn_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * mask,
|
||||
float scale) {
|
||||
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
||||
// TODO: check if vT can be multiplied by (k*qT)
|
||||
if (mask) {
|
||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||
GGML_ASSERT(mask->ne[2] == 1);
|
||||
GGML_ASSERT(mask->ne[3] == 1);
|
||||
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
|
||||
}
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (q->grad || k->grad || v->grad) {
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
//struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
|
||||
|
||||
float params[] = { scale };
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_FLASH_ATTN_EXT;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = q;
|
||||
result->src[1] = k;
|
||||
result->src[2] = v;
|
||||
result->src[3] = mask;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_flash_ff
|
||||
|
||||
struct ggml_tensor * ggml_flash_ff(
|
||||
@ -13212,6 +13254,251 @@ static void ggml_compute_forward_flash_attn(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_flash_attn_ext
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * q,
|
||||
const struct ggml_tensor * k,
|
||||
const struct ggml_tensor * v,
|
||||
const struct ggml_tensor * mask,
|
||||
struct ggml_tensor * dst) {
|
||||
int64_t t0 = ggml_perf_time_us();
|
||||
UNUSED(t0);
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t D = neq0;
|
||||
const int64_t N = neq1;
|
||||
const int64_t P = nek1 - N;
|
||||
const int64_t M = P + N;
|
||||
|
||||
const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
|
||||
|
||||
GGML_ASSERT(ne0 == D);
|
||||
GGML_ASSERT(ne1 == N);
|
||||
GGML_ASSERT(P >= 0);
|
||||
|
||||
GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
|
||||
|
||||
GGML_ASSERT(neq0 == D);
|
||||
GGML_ASSERT(nek0 == D);
|
||||
GGML_ASSERT(nev1 == D);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
GGML_ASSERT(nek1 == N + P);
|
||||
GGML_ASSERT(nev1 == D);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
// parallelize by q rows using ggml_vec_dot_f32
|
||||
|
||||
// total rows in q
|
||||
const int nr = neq1*neq2*neq3;
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
float scale = 1.0f;
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
|
||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
const int iq3 = ir/(neq2*neq1);
|
||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
||||
|
||||
float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
|
||||
|
||||
for (int i = M; i < Mup; ++i) {
|
||||
S[i] = -INFINITY;
|
||||
}
|
||||
|
||||
if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
|
||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
// k indices
|
||||
const int ik3 = iq3;
|
||||
const int ik2 = iq2 % nek2;
|
||||
const int ik1 = ic;
|
||||
|
||||
// S indices
|
||||
const int i1 = ik1;
|
||||
|
||||
ggml_vec_dot_f16(neq0,
|
||||
S + i1,
|
||||
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
||||
}
|
||||
} else {
|
||||
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
||||
// k indices
|
||||
const int ik3 = iq3;
|
||||
const int ik2 = iq2 % nek2;
|
||||
const int ik1 = ic;
|
||||
|
||||
// S indices
|
||||
const int i1 = ik1;
|
||||
|
||||
ggml_vec_dot_f16_unroll(neq0, nbk1,
|
||||
S + i1,
|
||||
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
||||
}
|
||||
}
|
||||
|
||||
// scale
|
||||
ggml_vec_scale_f32(nek1, S, scale);
|
||||
|
||||
if (mask) {
|
||||
const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]);
|
||||
ggml_vec_acc_f32(M, S, mp);
|
||||
}
|
||||
|
||||
// softmax
|
||||
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
||||
// dont forget to set their S values to zero
|
||||
{
|
||||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(M, &max, S);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
{
|
||||
#ifdef GGML_SOFT_MAX_ACCELERATE
|
||||
max = -max;
|
||||
vDSP_vsadd(S, 1, &max, S, 1, Mup);
|
||||
vvexpf(S, S, &Mup);
|
||||
ggml_vec_sum_f32(Mup, &sum, S);
|
||||
#else
|
||||
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
|
||||
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
||||
|
||||
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
||||
float * SS = S + i;
|
||||
|
||||
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
||||
if (SS[j] == -INFINITY) {
|
||||
SS[j] = 0.0f;
|
||||
} else {
|
||||
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
|
||||
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
||||
const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
|
||||
sump[j] += (ggml_float)val;
|
||||
SS[j] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
||||
sum += sump[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
assert(sum > 0.0);
|
||||
|
||||
sum = 1.0/sum;
|
||||
ggml_vec_scale_f32(M, S, sum);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < M; ++i) {
|
||||
assert(!isnan(S[i]));
|
||||
assert(!isinf(S[i]));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
|
||||
|
||||
for (int64_t i = 0; i < M; i++) {
|
||||
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
||||
}
|
||||
|
||||
// todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
|
||||
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
|
||||
for (int64_t ic = 0; ic < nev1; ++ic) {
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// v indices
|
||||
const int iv2 = iq2 % nev2;
|
||||
const int iv3 = iq3;
|
||||
|
||||
ggml_vec_dot_f16(nev0,
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||
S16);
|
||||
}
|
||||
} else {
|
||||
for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// v indices
|
||||
const int iv2 = iq2 % nev2;
|
||||
const int iv3 = iq3;
|
||||
|
||||
ggml_vec_dot_f16_unroll(nev0, nbv1,
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||
((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||
S16);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * q,
|
||||
const struct ggml_tensor * k,
|
||||
const struct ggml_tensor * v,
|
||||
const struct ggml_tensor * mask,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (q->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_flash_ff
|
||||
|
||||
static void ggml_compute_forward_flash_ff_f16(
|
||||
@ -14717,6 +15004,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
const bool masked = t != 0;
|
||||
ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
||||
} break;
|
||||
case GGML_OP_FLASH_FF:
|
||||
{
|
||||
ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
|
||||
@ -15713,6 +16004,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
GGML_ASSERT(false); // TODO: not implemented
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
struct ggml_tensor * flash_grad = NULL;
|
||||
if (src0->grad || src1->grad || tensor->src[2]->grad) {
|
||||
@ -16438,6 +16730,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@ -16769,6 +17062,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
|
||||
|
||||
|
9
ggml.h
9
ggml.h
@ -452,6 +452,7 @@ extern "C" {
|
||||
GGML_OP_LEAKY_RELU,
|
||||
|
||||
GGML_OP_FLASH_ATTN,
|
||||
GGML_OP_FLASH_ATTN_EXT,
|
||||
GGML_OP_FLASH_FF,
|
||||
GGML_OP_FLASH_ATTN_BACK,
|
||||
GGML_OP_WIN_PART,
|
||||
@ -1619,6 +1620,14 @@ extern "C" {
|
||||
struct ggml_tensor * v,
|
||||
bool masked);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * mask,
|
||||
float scale);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
|
32
llama.cpp
32
llama.cpp
@ -4205,6 +4205,23 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
0);
|
||||
cb(k, "k", il);
|
||||
|
||||
// split cached v into n_head heads
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx, kv.v_l[il],
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(kv.v_l[il])*n_ctx,
|
||||
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
|
||||
// TODO: determine if we can use flash attention
|
||||
const bool supports_flash_attn = true;
|
||||
|
||||
struct ggml_tensor * kqv;
|
||||
|
||||
if (supports_flash_attn) {
|
||||
kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
|
||||
} else {
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
@ -4237,17 +4254,9 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
}
|
||||
|
||||
// split cached v into n_head heads
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx, kv.v_l[il],
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(kv.v_l[il])*n_ctx,
|
||||
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
|
||||
kqv = ggml_mul_mat(ctx, v, kq);
|
||||
cb(kqv, "kqv", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
@ -9490,8 +9499,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
ctx->backends.push_back(ctx->backend_cpu);
|
||||
|
||||
if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v,
|
||||
cparams.n_ctx, cparams.offload_kqv)) {
|
||||
if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
|
||||
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
|
@ -1384,6 +1384,32 @@ struct test_leaky_relu : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_FLASH_ATTN_EXT
|
||||
struct test_flash_attn_ext : public test_case {
|
||||
const ggml_type typeq;
|
||||
const int64_t hs; // head size
|
||||
const int64_t nh; // num heads
|
||||
const int64_t kv; // kv size
|
||||
const int64_t nt; // tokens
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(typeq, hs, nh, kv, nt);
|
||||
}
|
||||
|
||||
test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
|
||||
int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8)
|
||||
: typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1);
|
||||
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1);
|
||||
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1);
|
||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// Mixtral MOE
|
||||
struct test_moe : public test_case {
|
||||
const int n_experts;
|
||||
@ -1650,6 +1676,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_pad());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
|
||||
test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8));
|
||||
|
||||
#if !defined(__SANITIZE_THREAD__)
|
||||
// FIXME: these tests use too much memory with thread sanitizer
|
||||
test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
|
||||
|
Loading…
Reference in New Issue
Block a user