mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
parent
e402de364b
commit
6369bf0433
@ -2204,11 +2204,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
device const half * mp = (device const half *) (mask + iq1*nb31);
|
||||||
|
|
||||||
// prepare diagonal scale matrix
|
float slope = 1.0f;
|
||||||
simdgroup_float8x8 mscale(scale);
|
|
||||||
|
|
||||||
// prepare diagonal slope matrix
|
|
||||||
simdgroup_float8x8 mslope(1.0f);
|
|
||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
if (max_bias > 0.0f) {
|
if (max_bias > 0.0f) {
|
||||||
@ -2217,7 +2213,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
const float base = h < n_head_log2 ? m0 : m1;
|
const float base = h < n_head_log2 ? m0 : m1;
|
||||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||||
|
|
||||||
mslope = simdgroup_float8x8(pow(base, exph));
|
slope = pow(base, exph);
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop over the KV cache
|
// loop over the KV cache
|
||||||
@ -2242,18 +2238,20 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
||||||
|
|
||||||
|
const short tx = tiisg%4;
|
||||||
|
const short ty = tiisg/4;
|
||||||
|
|
||||||
if (mask != q) {
|
if (mask != q) {
|
||||||
// mqk = mqk*scale + mask*slope
|
// mqk = mqk*scale + mask*slope
|
||||||
simdgroup_half8x8 mm;
|
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
||||||
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
||||||
simdgroup_multiply(mm, mslope, mm);
|
|
||||||
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
|
||||||
} else {
|
} else {
|
||||||
// mqk = mqk*scale
|
// mqk = mqk*scale
|
||||||
simdgroup_multiply(mqk, mscale, mqk);
|
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
||||||
|
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2816,8 +2814,7 @@ kernel void kernel_cpy_f32_f16(
|
|||||||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||||
|
|
||||||
// TODO: is there a better way to handle -INFINITY?
|
dst_data[i00] = src[0];
|
||||||
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user