mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
cuda : unroll Q*K^T loop
This commit is contained in:
parent
3b1c4e7673
commit
5b263dd83a
@ -6571,6 +6571,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
|
#pragma unroll
|
||||||
for (int cc = 0; cc < C/16; ++cc) {
|
for (int cc = 0; cc < C/16; ++cc) {
|
||||||
half16x16_acc mqk[Q16];
|
half16x16_acc mqk[Q16];
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
|
Loading…
Reference in New Issue
Block a user