mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 22:08:55 +01:00
cuda : reduce registers
This commit is contained in:
parent
5dd355fe26
commit
4c190ba676
@ -175,12 +175,12 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
const int iv3 = iq3 / rv3;
|
const int iv3 = iq3 / rv3;
|
||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
half16x16_a mq[Q16][D16];
|
//half16x16_a mq[Q16][D16];
|
||||||
for (int j = 0; j < Q16; ++j) {
|
//for (int j = 0; j < Q16; ++j) {
|
||||||
for (int i = 0; i < D16; ++i) {
|
// for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
|
// nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
|
||||||
}
|
// }
|
||||||
}
|
//}
|
||||||
|
|
||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
|
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
|
||||||
@ -216,7 +216,9 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
||||||
|
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
|
half16x16_a mq;
|
||||||
|
nvcuda::wmma::load_matrix_sync(mq, sq + 16*j*T + i*16, T);
|
||||||
|
nvcuda::wmma::mma_sync(mqk[j], mq, mk, mqk[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -319,19 +321,13 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
for (int cc = 0; cc < C16; ++cc) {
|
for (int cc = 0; cc < C16; ++cc) {
|
||||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
half16x16_b mv[D16];
|
|
||||||
for (int i = 0; i < D16; ++i) {
|
|
||||||
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
|
|
||||||
}
|
|
||||||
|
|
||||||
half16x16_a ms[Q16];
|
|
||||||
for (int j = 0; j < Q16; ++j) {
|
|
||||||
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
|
half16x16_a ms;
|
||||||
|
nvcuda::wmma::load_matrix_sync(ms, ss + 16*j*T + 16*cc, T);
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
|
half16x16_b mv;
|
||||||
|
nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half));
|
||||||
|
nvcuda::wmma::mma_sync(lo[j][i], ms, mv, lo[j][i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -554,6 +550,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||||||
} break;
|
} break;
|
||||||
case 128:
|
case 128:
|
||||||
{
|
{
|
||||||
|
//const size_t shmem_max = 96*1024;
|
||||||
|
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
||||||
|
|
||||||
flash_attn_ext_f16<128, NQPB, NCPW>
|
flash_attn_ext_f16<128, NQPB, NCPW>
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
(const char *) Q->data, // Query
|
(const char *) Q->data, // Query
|
||||||
@ -572,9 +571,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||||||
} break;
|
} break;
|
||||||
case 256:
|
case 256:
|
||||||
{
|
{
|
||||||
// increase shared memory limit to 64KB
|
const size_t shmem_max = 64*1024;
|
||||||
//const size_t shmem_max = 64*1024;
|
cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
||||||
//cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
|
||||||
|
|
||||||
flash_attn_ext_f16<256, NQPB, NCPW>
|
flash_attn_ext_f16<256, NQPB, NCPW>
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
|
Loading…
Reference in New Issue
Block a user