mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00
metal : add tests, fix scaling, support C > 32
This commit is contained in:
parent
77f6976a87
commit
ecc466a460
@ -2213,12 +2213,12 @@ static bool ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||||
|
|
||||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8)
|
||||||
const int64_t ncpsg = 32; // cache values per simdgroup
|
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32)
|
||||||
|
|
||||||
// simdgroups per threadgroup (a.k.a. warps)
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
|
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
|
||||||
|
|
||||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||||
|
|
||||||
|
@ -2041,7 +2041,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
const int64_t D4 = D/4;
|
const int64_t D4 = D/4;
|
||||||
const int64_t D8 = D/8;
|
const int64_t D8 = D/8;
|
||||||
const int64_t NW = N_SIMDWIDTH;
|
const int64_t NW = N_SIMDWIDTH;
|
||||||
const int64_t L4 = (D4 + NW - 1)/NW;
|
|
||||||
const int64_t SH = (C + Q); // shared memory per simdgroup in (half)
|
const int64_t SH = (C + Q); // shared memory per simdgroup in (half)
|
||||||
|
|
||||||
const int64_t T = D + nsg*SH; // shared memory size per query in (half)
|
const int64_t T = D + nsg*SH; // shared memory size per query in (half)
|
||||||
@ -2054,14 +2053,15 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||||
simdgroup_half8x8 lo[D8];
|
simdgroup_half8x8 lo[D8];
|
||||||
|
|
||||||
for (int64_t i = 0; i < L4; ++i) {
|
// load heads from Q to shared memory
|
||||||
// load heads from Q to shared memory
|
for (int64_t j = sgitg; j < Q; j += nsg) {
|
||||||
for (int64_t j = sgitg; j < Q; j += nsg) {
|
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
|
||||||
|
for (int64_t i = tiisg; i < D4; i += NW) {
|
||||||
if (iq1 + j < ne01) {
|
if (iq1 + j < ne01) {
|
||||||
sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg];
|
sq4[j*T4 + i] = (half4) q4[i];
|
||||||
} else {
|
} else {
|
||||||
sq4[j*T4 + NW*i + tiisg] = 0.0h;
|
sq4[j*T4 + i] = 0.0h;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2072,12 +2072,9 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// zero out shared memory SH
|
// zero out shared memory SH
|
||||||
if (tiisg < C) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t i = tiisg; i < SH; i += NW) {
|
||||||
ss[j*T + tiisg] = 0.0h;
|
ss[j*T + i] = 0.0h;
|
||||||
if (tiisg < Q) {
|
|
||||||
ss[j*T + C + tiisg] = 0.0h;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2157,27 +2154,34 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
const int64_t p = tiisg;
|
|
||||||
|
|
||||||
const half s = ss[j*T + p];
|
|
||||||
|
|
||||||
smax = simd_max(max(smax, s));
|
|
||||||
M[j] = simd_max(max(M[j], s));
|
|
||||||
|
|
||||||
const half m = M[j];
|
const half m = M[j];
|
||||||
|
|
||||||
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
for (int64_t p = tiisg; p < C; p += NW) {
|
||||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
S[j] = S[j]*ms + simd_sum(vs);
|
smax = simd_max(max(smax, s));
|
||||||
|
M[j] = simd_max(max(M[j], s));
|
||||||
|
}
|
||||||
|
|
||||||
|
const half ms = exp(m - M[j]);
|
||||||
|
|
||||||
|
S[j] = S[j]*ms;
|
||||||
|
|
||||||
// create an 8x8 diagonal matrix for rescaling the output
|
// create an 8x8 diagonal matrix for rescaling the output
|
||||||
if (p == j) {
|
if (tiisg == j) {
|
||||||
ss[j*T + C + j] = ms;
|
ss[j*T + C + j] = ms;
|
||||||
}
|
}
|
||||||
|
|
||||||
// the P matrix from the paper (Q rows, C columns)
|
for (int64_t p = tiisg; p < C; p += NW) {
|
||||||
ss[j*T + p] = vs;
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
|
const half vs = exp(s - M[j]);
|
||||||
|
|
||||||
|
S[j] = S[j] + simd_sum(vs);
|
||||||
|
|
||||||
|
// the P matrix from the paper (Q rows, C columns)
|
||||||
|
ss[j*T + p] = vs;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// skip -INF blocks
|
// skip -INF blocks
|
||||||
@ -2231,7 +2235,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// each simdgroup stores its output to shared memory, reusing sq4
|
// each simdgroup stores its output to shared memory, reusing sq
|
||||||
if (sgitg == sg) {
|
if (sgitg == sg) {
|
||||||
for (int64_t i = 0; i < D8; ++i) {
|
for (int64_t i = 0; i < D8; ++i) {
|
||||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||||
@ -2284,7 +2288,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// store result to shared memory (reuse sq4)
|
// store result to shared memory (reuse sq)
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
for (int64_t i = 0; i < D8; ++i) {
|
for (int64_t i = 0; i < D8; ++i) {
|
||||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||||
@ -2298,8 +2302,8 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
|
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||||
const half S = ss[j*T + 0];
|
const half S = ss[j*T + 0];
|
||||||
|
|
||||||
for (int64_t i = 0; i < L4; ++i) {
|
for (int64_t i = tiisg; i < D4; i += NW) {
|
||||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S;
|
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1395,7 +1395,7 @@ struct test_flash_attn_ext : public test_case {
|
|||||||
}
|
}
|
||||||
|
|
||||||
double max_nmse_err() override {
|
double max_nmse_err() override {
|
||||||
return 5e-5;
|
return 5e-4;
|
||||||
}
|
}
|
||||||
|
|
||||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
|
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
|
||||||
@ -1677,9 +1677,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
test_cases.emplace_back(new test_pad());
|
test_cases.emplace_back(new test_pad());
|
||||||
test_cases.emplace_back(new test_leaky_relu());
|
test_cases.emplace_back(new test_leaky_relu());
|
||||||
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8));
|
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7));
|
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1));
|
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1));
|
||||||
|
|
||||||
#if !defined(__SANITIZE_THREAD__)
|
#if !defined(__SANITIZE_THREAD__)
|
||||||
// FIXME: these tests use too much memory with thread sanitizer
|
// FIXME: these tests use too much memory with thread sanitizer
|
||||||
|
Loading…
Reference in New Issue
Block a user