mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
cuda : bump nwarps by 1
This commit is contained in:
parent
08e69c5008
commit
5dd355fe26
@ -471,120 +471,127 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
|
||||
const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
|
||||
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
|
||||
const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;
|
||||
const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) + 1 : 1;
|
||||
|
||||
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
|
||||
dim3 block_dim(32, nwarps, 1);
|
||||
|
||||
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
|
||||
|
||||
// increase shared memory limit to 96KB
|
||||
//const size_t shmem_max = 96*1024;
|
||||
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
||||
//printf("nwarps: %d, shm: %zu\n", nwarps, shmem);
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
flash_attn_ext_f16<64, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
{
|
||||
flash_attn_ext_f16<64, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
} break;
|
||||
case 80:
|
||||
flash_attn_ext_f16<80, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
{
|
||||
flash_attn_ext_f16<80, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
} break;
|
||||
case 96:
|
||||
flash_attn_ext_f16<96, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
{
|
||||
flash_attn_ext_f16<96, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
} break;
|
||||
case 112:
|
||||
flash_attn_ext_f16<112, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
{
|
||||
flash_attn_ext_f16<112, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
} break;
|
||||
case 128:
|
||||
flash_attn_ext_f16<128, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
{
|
||||
flash_attn_ext_f16<128, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
} break;
|
||||
case 256:
|
||||
flash_attn_ext_f16<256, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
{
|
||||
// increase shared memory limit to 64KB
|
||||
//const size_t shmem_max = 64*1024;
|
||||
//cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
||||
|
||||
flash_attn_ext_f16<256, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? (const char *) mask->data : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user