mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 06:10:29 +01:00
cuda : reduce registers
This commit is contained in:
parent
5dd355fe26
commit
4c190ba676
@ -175,12 +175,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int iv3 = iq3 / rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
half16x16_a mq[Q16][D16];
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
|
||||
}
|
||||
}
|
||||
//half16x16_a mq[Q16][D16];
|
||||
//for (int j = 0; j < Q16; ++j) {
|
||||
// for (int i = 0; i < D16; ++i) {
|
||||
// nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
|
||||
// }
|
||||
//}
|
||||
|
||||
// pointer to the mask
|
||||
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
|
||||
@ -216,7 +216,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
||||
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
|
||||
half16x16_a mq;
|
||||
nvcuda::wmma::load_matrix_sync(mq, sq + 16*j*T + i*16, T);
|
||||
nvcuda::wmma::mma_sync(mqk[j], mq, mk, mqk[j]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -319,19 +321,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int cc = 0; cc < C16; ++cc) {
|
||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
half16x16_b mv[D16];
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
|
||||
}
|
||||
|
||||
half16x16_a ms[Q16];
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
|
||||
}
|
||||
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
half16x16_a ms;
|
||||
nvcuda::wmma::load_matrix_sync(ms, ss + 16*j*T + 16*cc, T);
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
|
||||
half16x16_b mv;
|
||||
nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half));
|
||||
nvcuda::wmma::mma_sync(lo[j][i], ms, mv, lo[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -554,6 +550,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
} break;
|
||||
case 128:
|
||||
{
|
||||
//const size_t shmem_max = 96*1024;
|
||||
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
||||
|
||||
flash_attn_ext_f16<128, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
@ -572,9 +571,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
} break;
|
||||
case 256:
|
||||
{
|
||||
// increase shared memory limit to 64KB
|
||||
//const size_t shmem_max = 64*1024;
|
||||
//cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
||||
const size_t shmem_max = 64*1024;
|
||||
cudaFuncSetAttribute(flash_attn_ext_f16<256, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
||||
|
||||
flash_attn_ext_f16<256, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
|
Loading…
Reference in New Issue
Block a user