cuda : bump nwarps by 1

This commit is contained in:
Georgi Gerganov 2024-03-28 20:21:09 +02:00
parent 08e69c5008
commit 5dd355fe26
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -471,120 +471,127 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) + 1 : 1;
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1); dim3 block_dim(32, nwarps, 1);
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
//printf("nwarps: %d, shm: %zu\n", nwarps, shmem);
// increase shared memory limit to 96KB
//const size_t shmem_max = 96*1024;
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
switch (Q->ne[0]) { switch (Q->ne[0]) {
case 64: case 64:
flash_attn_ext_f16<64, NQPB, NCPW> {
<<<blocks_num, block_dim, shmem, main_stream>>> ( flash_attn_ext_f16<64, NQPB, NCPW>
(const char *) Q->data, // Query <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) K->data, // Key (const char *) Q->data, // Query
(const char *) V->data, // Value (const char *) K->data, // Key
mask ? (const char *) mask->data : nullptr, // Mask (const char *) V->data, // Value
(float *) KQV->data, // dst mask ? (const char *) mask->data : nullptr, // Mask
scale, (float *) KQV->data, // dst
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], scale,
K->ne[0], K->ne[1], K->ne[2], K->ne[3], Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, K->ne[0], K->ne[1], K->ne[2], K->ne[3],
Q->nb[1], Q->nb[2], Q->nb[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
K->nb[1], K->nb[2], K->nb[3], Q->nb[1], Q->nb[2], Q->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] K->nb[1], K->nb[2], K->nb[3],
); KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
break; );
} break;
case 80: case 80:
flash_attn_ext_f16<80, NQPB, NCPW> {
<<<blocks_num, block_dim, shmem, main_stream>>> ( flash_attn_ext_f16<80, NQPB, NCPW>
(const char *) Q->data, // Query <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) K->data, // Key (const char *) Q->data, // Query
(const char *) V->data, // Value (const char *) K->data, // Key
mask ? (const char *) mask->data : nullptr, // Mask (const char *) V->data, // Value
(float *) KQV->data, // dst mask ? (const char *) mask->data : nullptr, // Mask
scale, (float *) KQV->data, // dst
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], scale,
K->ne[0], K->ne[1], K->ne[2], K->ne[3], Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, K->ne[0], K->ne[1], K->ne[2], K->ne[3],
Q->nb[1], Q->nb[2], Q->nb[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
K->nb[1], K->nb[2], K->nb[3], Q->nb[1], Q->nb[2], Q->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] K->nb[1], K->nb[2], K->nb[3],
); KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
break; );
} break;
case 96: case 96:
flash_attn_ext_f16<96, NQPB, NCPW> {
<<<blocks_num, block_dim, shmem, main_stream>>> ( flash_attn_ext_f16<96, NQPB, NCPW>
(const char *) Q->data, // Query <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) K->data, // Key (const char *) Q->data, // Query
(const char *) V->data, // Value (const char *) K->data, // Key
mask ? (const char *) mask->data : nullptr, // Mask (const char *) V->data, // Value
(float *) KQV->data, // dst mask ? (const char *) mask->data : nullptr, // Mask
scale, (float *) KQV->data, // dst
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], scale,
K->ne[0], K->ne[1], K->ne[2], K->ne[3], Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, K->ne[0], K->ne[1], K->ne[2], K->ne[3],
Q->nb[1], Q->nb[2], Q->nb[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
K->nb[1], K->nb[2], K->nb[3], Q->nb[1], Q->nb[2], Q->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] K->nb[1], K->nb[2], K->nb[3],
); KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
break; );
} break;
case 112: case 112:
flash_attn_ext_f16<112, NQPB, NCPW> {
<<<blocks_num, block_dim, shmem, main_stream>>> ( flash_attn_ext_f16<112, NQPB, NCPW>
(const char *) Q->data, // Query <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) K->data, // Key (const char *) Q->data, // Query
(const char *) V->data, // Value (const char *) K->data, // Key
mask ? (const char *) mask->data : nullptr, // Mask (const char *) V->data, // Value
(float *) KQV->data, // dst mask ? (const char *) mask->data : nullptr, // Mask
scale, (float *) KQV->data, // dst
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], scale,
K->ne[0], K->ne[1], K->ne[2], K->ne[3], Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, K->ne[0], K->ne[1], K->ne[2], K->ne[3],
Q->nb[1], Q->nb[2], Q->nb[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
K->nb[1], K->nb[2], K->nb[3], Q->nb[1], Q->nb[2], Q->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] K->nb[1], K->nb[2], K->nb[3],
); KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
break; );
} break;
case 128: case 128:
flash_attn_ext_f16<128, NQPB, NCPW> {
<<<blocks_num, block_dim, shmem, main_stream>>> ( flash_attn_ext_f16<128, NQPB, NCPW>
(const char *) Q->data, // Query <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) K->data, // Key (const char *) Q->data, // Query
(const char *) V->data, // Value (const char *) K->data, // Key
mask ? (const char *) mask->data : nullptr, // Mask (const char *) V->data, // Value
(float *) KQV->data, // dst mask ? (const char *) mask->data : nullptr, // Mask
scale, (float *) KQV->data, // dst
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], scale,
K->ne[0], K->ne[1], K->ne[2], K->ne[3], Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, K->ne[0], K->ne[1], K->ne[2], K->ne[3],
Q->nb[1], Q->nb[2], Q->nb[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
K->nb[1], K->nb[2], K->nb[3], Q->nb[1], Q->nb[2], Q->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] K->nb[1], K->nb[2], K->nb[3],
); KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
break; );
} break;
case 256: case 256:
flash_attn_ext_f16<256, NQPB, NCPW> {
<<<blocks_num, block_dim, shmem, main_stream>>> ( // increase shared memory limit to 64KB
(const char *) Q->data, // Query //const size_t shmem_max = 64*1024;
(const char *) K->data, // Key //cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
(const char *) V->data, // Value
mask ? (const char *) mask->data : nullptr, // Mask flash_attn_ext_f16<256, NQPB, NCPW>
(float *) KQV->data, // dst <<<blocks_num, block_dim, shmem, main_stream>>> (
scale, (const char *) Q->data, // Query
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], (const char *) K->data, // Key
K->ne[0], K->ne[1], K->ne[2], K->ne[3], (const char *) V->data, // Value
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, mask ? (const char *) mask->data : nullptr, // Mask
Q->nb[1], Q->nb[2], Q->nb[3], (float *) KQV->data, // dst
K->nb[1], K->nb[2], K->nb[3], scale,
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
); K->ne[0], K->ne[1], K->ne[2], K->ne[3],
break; mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
} break;
default: default:
break; break;
} }