cuda : unroll Q*K^T loop

This commit is contained in:
Georgi Gerganov 2024-02-03 16:12:20 +02:00
parent 3b1c4e7673
commit 5b263dd83a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6571,6 +6571,7 @@ static __global__ void flash_attn_ext_f16(
// Q*K^T
{
#pragma unroll
for (int cc = 0; cc < C/16; ++cc) {
half16x16_acc mqk[Q16];
for (int j = 0; j < Q16; ++j) {