cuda : reduce registers

This commit is contained in:
Georgi Gerganov 2024-03-28 21:17:08 +02:00
parent 5dd355fe26
commit 4c190ba676
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -175,12 +175,12 @@ static __global__ void flash_attn_ext_f16(
const int iv3 = iq3 / rv3; const int iv3 = iq3 / rv3;
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
half16x16_a mq[Q16][D16]; //half16x16_a mq[Q16][D16];
for (int j = 0; j < Q16; ++j) { //for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) { // for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); // nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
} // }
} //}
// pointer to the mask // pointer to the mask
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
@ -216,7 +216,9 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); half16x16_a mq;
nvcuda::wmma::load_matrix_sync(mq, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(mqk[j], mq, mk, mqk[j]);
} }
} }
@ -319,19 +321,13 @@ static __global__ void flash_attn_ext_f16(
for (int cc = 0; cc < C16; ++cc) { for (int cc = 0; cc < C16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
half16x16_b mv[D16];
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
}
half16x16_a ms[Q16];
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
}
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < Q16; ++j) {
half16x16_a ms;
nvcuda::wmma::load_matrix_sync(ms, ss + 16*j*T + 16*cc, T);
for (int i = 0; i < D16; ++i) { for (int i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); half16x16_b mv;
nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half));
nvcuda::wmma::mma_sync(lo[j][i], ms, mv, lo[j][i]);
} }
} }
} }
@ -554,6 +550,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
} break; } break;
case 128: case 128:
{ {
//const size_t shmem_max = 96*1024;
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
flash_attn_ext_f16<128, NQPB, NCPW> flash_attn_ext_f16<128, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> ( <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query (const char *) Q->data, // Query
@ -572,9 +571,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
} break; } break;
case 256: case 256:
{ {
// increase shared memory limit to 64KB const size_t shmem_max = 64*1024;
//const size_t shmem_max = 64*1024; cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
//cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
flash_attn_ext_f16<256, NQPB, NCPW> flash_attn_ext_f16<256, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> ( <<<blocks_num, block_dim, shmem, main_stream>>> (