cuda : bump nwarps by 1

Georgi Gerganov 2024-03-28 20:21:09 +02:00
parent 08e69c5008
commit 5dd355fe26

@@ -471,19 +471,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
-const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;
+const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) + 1 : 1;
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1);
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
// increase shared memory limit to 96KB
//const size_t shmem_max = 96*1024;
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
//printf("nwarps: %d, shm: %zu\n", nwarps, shmem);
switch (Q->ne[0]) {
case 64:
{
flash_attn_ext_f16<64, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
@@ -499,8 +497,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
-break;
+} break;
case 80:
{
flash_attn_ext_f16<80, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
@@ -516,8 +515,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
-break;
+} break;
case 96:
{
flash_attn_ext_f16<96, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
@@ -533,8 +533,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
-break;
+} break;
case 112:
{
flash_attn_ext_f16<112, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
@@ -550,8 +551,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
-break;
+} break;
case 128:
{
flash_attn_ext_f16<128, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
@@ -567,8 +569,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
-break;
+} break;
case 256:
{
+// increase shared memory limit to 64KB
+//const size_t shmem_max = 64*1024;
+//cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
flash_attn_ext_f16<256, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
@@ -584,7 +591,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
-break;
+} break;
default:
break;
}
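
For reference, a standalone sketch of the launch-configuration arithmetic in the first hunk. The tile parameters NQPB/NCPW and the tensor shapes below are illustrative assumptions, not values taken from this commit; the real values come from macros and runtime tensors in the surrounding file.

#include <algorithm>
#include <cstdio>

int main() {
    // assumed tile parameters (the real NQPB/NCPW are macros elsewhere in the file)
    const int nqpb = 16;   // query rows processed per block
    const int ncpw = 128;  // KV entries processed per warp

    // assumed tensor shapes
    const int n_q  = 8;    // Q->ne[1]: number of query tokens
    const int n_kv = 4096; // K->ne[1]: number of KV cache entries

    const int nwarps_max = 8;

    // after this commit: one extra warp on top of the clamped value
    const int nwarps = n_q <= nqpb ? std::max(2, std::min(n_kv/ncpw, nwarps_max)) + 1 : 1;

    const int blocks_x = (n_q + nqpb - 1)/nqpb;

    // dynamic shared memory: nqpb rows of (head_dim + nwarps*(ncpw + nqpb)) half-precision values
    const int head_dims[] = {64, 80, 96, 112, 128, 256};
    for (int head_dim : head_dims) {
        const size_t shmem = (size_t) nqpb*(head_dim + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
        printf("head_dim %3d: nwarps = %d, block = 32x%d, blocks.x = %d, shmem = %zu bytes\n",
               head_dim, nwarps, nwarps, blocks_x, shmem);
    }
    return 0;
}

With these assumed numbers the head-size-256 case needs about 48.5 KB of dynamic shared memory per block, just over the default 48 KB limit, which is what the commented-out shared-memory-limit lines inside the case 256 branch are addressing.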
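The commented-out cudaFuncSetAttribute calls use the documented CUDA mechanism (cudaFuncAttributeMaxDynamicSharedMemorySize, compute capability 7.0+) for opting a kernel into more than 48 KB of dynamic shared memory. A minimal, self-contained sketch with a toy kernel (not the flash-attention kernel):

#include <cuda_runtime.h>
#include <cstdio>

__global__ void fill_shared(float * out, int n) {
    extern __shared__ float buf[];              // dynamic shared memory
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        buf[i] = (float) i;
    }
    __syncthreads();
    if (threadIdx.x == 0) {
        out[0] = buf[n - 1];
    }
}

int main() {
    const int    n     = 16*1024;               // 16K floats
    const size_t shmem = n*sizeof(float);       // 64 KB, above the default 48 KB limit

    // opt this kernel into the larger dynamic shared-memory carveout
    cudaFuncSetAttribute(fill_shared, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) shmem);

    float * d_out = nullptr;
    cudaMalloc(&d_out, sizeof(float));

    fill_shared<<<1, 128, shmem>>>(d_out, n);
    const cudaError_t err = cudaGetLastError(); // reports launch-configuration errors

    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("launch: %s, last element: %.1f\n", cudaGetErrorString(err), h_out);

    cudaFree(d_out);
    return 0;
}

Without the attribute call, a launch requesting 64 KB of dynamic shared memory fails with an invalid-configuration error, which matches the intent of the commented-out 64 KB block added for head size 256.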