mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 14:28:58 +01:00
cuda : bump nwarps by 1
This commit is contained in:
parent
08e69c5008
commit
5dd355fe26
@ -471,120 +471,127 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||||||
|
|
||||||
const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
|
const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
|
||||||
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
|
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
|
||||||
const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;
|
const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) + 1 : 1;
|
||||||
|
|
||||||
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
|
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
|
||||||
dim3 block_dim(32, nwarps, 1);
|
dim3 block_dim(32, nwarps, 1);
|
||||||
|
|
||||||
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
|
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
|
||||||
|
//printf("nwarps: %d, shm: %zu\n", nwarps, shmem);
|
||||||
// increase shared memory limit to 96KB
|
|
||||||
//const size_t shmem_max = 96*1024;
|
|
||||||
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
|
||||||
|
|
||||||
switch (Q->ne[0]) {
|
switch (Q->ne[0]) {
|
||||||
case 64:
|
case 64:
|
||||||
flash_attn_ext_f16<64, NQPB, NCPW>
|
{
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
flash_attn_ext_f16<64, NQPB, NCPW>
|
||||||
(const char *) Q->data, // Query
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
(const char *) K->data, // Key
|
(const char *) Q->data, // Query
|
||||||
(const char *) V->data, // Value
|
(const char *) K->data, // Key
|
||||||
mask ? (const char *) mask->data : nullptr, // Mask
|
(const char *) V->data, // Value
|
||||||
(float *) KQV->data, // dst
|
mask ? (const char *) mask->data : nullptr, // Mask
|
||||||
scale,
|
(float *) KQV->data, // dst
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
scale,
|
||||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
K->nb[1], K->nb[2], K->nb[3],
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
);
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
break;
|
);
|
||||||
|
} break;
|
||||||
case 80:
|
case 80:
|
||||||
flash_attn_ext_f16<80, NQPB, NCPW>
|
{
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
flash_attn_ext_f16<80, NQPB, NCPW>
|
||||||
(const char *) Q->data, // Query
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
(const char *) K->data, // Key
|
(const char *) Q->data, // Query
|
||||||
(const char *) V->data, // Value
|
(const char *) K->data, // Key
|
||||||
mask ? (const char *) mask->data : nullptr, // Mask
|
(const char *) V->data, // Value
|
||||||
(float *) KQV->data, // dst
|
mask ? (const char *) mask->data : nullptr, // Mask
|
||||||
scale,
|
(float *) KQV->data, // dst
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
scale,
|
||||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
K->nb[1], K->nb[2], K->nb[3],
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
);
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
break;
|
);
|
||||||
|
} break;
|
||||||
case 96:
|
case 96:
|
||||||
flash_attn_ext_f16<96, NQPB, NCPW>
|
{
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
flash_attn_ext_f16<96, NQPB, NCPW>
|
||||||
(const char *) Q->data, // Query
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
(const char *) K->data, // Key
|
(const char *) Q->data, // Query
|
||||||
(const char *) V->data, // Value
|
(const char *) K->data, // Key
|
||||||
mask ? (const char *) mask->data : nullptr, // Mask
|
(const char *) V->data, // Value
|
||||||
(float *) KQV->data, // dst
|
mask ? (const char *) mask->data : nullptr, // Mask
|
||||||
scale,
|
(float *) KQV->data, // dst
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
scale,
|
||||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
K->nb[1], K->nb[2], K->nb[3],
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
);
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
break;
|
);
|
||||||
|
} break;
|
||||||
case 112:
|
case 112:
|
||||||
flash_attn_ext_f16<112, NQPB, NCPW>
|
{
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
flash_attn_ext_f16<112, NQPB, NCPW>
|
||||||
(const char *) Q->data, // Query
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
(const char *) K->data, // Key
|
(const char *) Q->data, // Query
|
||||||
(const char *) V->data, // Value
|
(const char *) K->data, // Key
|
||||||
mask ? (const char *) mask->data : nullptr, // Mask
|
(const char *) V->data, // Value
|
||||||
(float *) KQV->data, // dst
|
mask ? (const char *) mask->data : nullptr, // Mask
|
||||||
scale,
|
(float *) KQV->data, // dst
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
scale,
|
||||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
K->nb[1], K->nb[2], K->nb[3],
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
);
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
break;
|
);
|
||||||
|
} break;
|
||||||
case 128:
|
case 128:
|
||||||
flash_attn_ext_f16<128, NQPB, NCPW>
|
{
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
flash_attn_ext_f16<128, NQPB, NCPW>
|
||||||
(const char *) Q->data, // Query
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
(const char *) K->data, // Key
|
(const char *) Q->data, // Query
|
||||||
(const char *) V->data, // Value
|
(const char *) K->data, // Key
|
||||||
mask ? (const char *) mask->data : nullptr, // Mask
|
(const char *) V->data, // Value
|
||||||
(float *) KQV->data, // dst
|
mask ? (const char *) mask->data : nullptr, // Mask
|
||||||
scale,
|
(float *) KQV->data, // dst
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
scale,
|
||||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
K->nb[1], K->nb[2], K->nb[3],
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
);
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
break;
|
);
|
||||||
|
} break;
|
||||||
case 256:
|
case 256:
|
||||||
flash_attn_ext_f16<256, NQPB, NCPW>
|
{
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
// increase shared memory limit to 64KB
|
||||||
(const char *) Q->data, // Query
|
//const size_t shmem_max = 64*1024;
|
||||||
(const char *) K->data, // Key
|
//cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
||||||
(const char *) V->data, // Value
|
|
||||||
mask ? (const char *) mask->data : nullptr, // Mask
|
flash_attn_ext_f16<256, NQPB, NCPW>
|
||||||
(float *) KQV->data, // dst
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
scale,
|
(const char *) Q->data, // Query
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
(const char *) K->data, // Key
|
||||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
(const char *) V->data, // Value
|
||||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
mask ? (const char *) mask->data : nullptr, // Mask
|
||||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
(float *) KQV->data, // dst
|
||||||
K->nb[1], K->nb[2], K->nb[3],
|
scale,
|
||||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
);
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
break;
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
|
);
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user