metal : gemma2 flash attention support (#9159)

This commit is contained in:
slaren 2024-08-26 11:08:59 +02:00 committed by GitHub
parent f12ceaca0c
commit 0c41e03ceb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 44 deletions

View File

@ -802,15 +802,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
if (op->src[0]->ne[0] == 256) { if (op->src[0]->ne[0] == 256) {
return false; return false;
} }
{
float logit_softcap;
memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
if (logit_softcap != 0.0f) {
return false;
}
}
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
@ -2633,9 +2624,14 @@ static enum ggml_status ggml_metal_graph_compute(
float scale; float scale;
float max_bias; float max_bias;
float logit_softcap;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
const uint32_t n_head = src0->ne[2]; const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@ -2710,6 +2706,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
if (!use_vec_kernel) { if (!use_vec_kernel) {
// half8x8 kernel // half8x8 kernel

View File

@ -1976,6 +1976,7 @@ typedef void (flash_attn_ext_f16_t)(
constant float & m0, constant float & m0,
constant float & m1, constant float & m1,
constant uint32_t & n_head_log2, constant uint32_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared, threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
@ -2014,6 +2015,7 @@ kernel void kernel_flash_attn_ext_f16(
constant float & m0, constant float & m0,
constant float & m1, constant float & m1,
constant uint32_t & n_head_log2, constant uint32_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]], threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
@ -2142,14 +2144,19 @@ kernel void kernel_flash_attn_ext_f16(
const short tx = tiisg%4; const short tx = tiisg%4;
const short ty = tiisg/4; const short ty = tiisg/4;
if (mask != q) {
// mqk = mqk*scale + mask*slope
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
} else {
// mqk = mqk*scale // mqk = mqk*scale
ss[8*cc + ty*TF + 2*tx + 0] *= scale; ss[8*cc + ty*TF + 2*tx + 0] *= scale;
ss[8*cc + ty*TF + 2*tx + 1] *= scale; ss[8*cc + ty*TF + 2*tx + 1] *= scale;
if (logit_softcap != 0.0f) {
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
}
if (mask != q) {
// mqk = mqk + mask*slope
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
} }
} }
} }
@ -2345,6 +2352,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
constant float & m0, constant float & m0,
constant float & m1, constant float & m1,
constant uint32_t & n_head_log2, constant uint32_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]], threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
@ -2479,7 +2487,13 @@ kernel void kernel_flash_attn_ext_vec_f16(
// mqk = mqk*scale + mask*slope // mqk = mqk*scale + mask*slope
if (tiisg == 0) { if (tiisg == 0) {
mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); mqk *= scale;
if (logit_softcap != 0.0f) {
mqk = logit_softcap*precise::tanh(mqk);
}
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
ss4[cc] = mqk; ss4[cc] = mqk;
} }

View File

@ -2487,7 +2487,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
} }
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
return false;
} }
static void usage(char ** argv) { static void usage(char ** argv) {