cuda : minor

This commit is contained in:
Georgi Gerganov 2024-02-04 09:57:58 +02:00
parent ef68fac2a8
commit 1846e92a90
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6443,15 +6443,17 @@ static __global__ void flash_attn_ext_f16(
const int iq2 = blockIdx.y;
const int iq1 = blockIdx.x * Q;
const int D2 = D/2;
const int D16 = D/16;
const int Q16 = Q/16;
const int C16 = C/16;
const int NW = WARP_SIZE;
const int SH = (C + Q); // shared memory per simdgroup in (half)
const int T = D + num_warps*SH; // shared memory size per query in (half)
const int T2 = T/2; // shared memory size per query in (half2)
const int C2 = C/2;
const int D2 = D/2;
extern __shared__ half __flash_attn_f16_shmem[];
// pq
@ -6571,7 +6573,7 @@ static __global__ void flash_attn_ext_f16(
// Q*K^T
{
#pragma unroll
for (int cc = 0; cc < C/16; ++cc) {
for (int cc = 0; cc < C16; ++cc) {
half16x16_acc mqk[Q16];
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::fill_fragment(mqk[j], 0);
@ -6684,7 +6686,7 @@ static __global__ void flash_attn_ext_f16(
// O = O + (Q*K^T)*V
{
for (int cc = 0; cc < C/16; ++cc) {
for (int cc = 0; cc < C16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
half16x16_b mv[D16];
@ -6707,11 +6709,9 @@ static __global__ void flash_attn_ext_f16(
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (int j = 0; j < Q; ++j) {
if (lane_id == j) {
ss[j*T + 0] = S;
ss[j*T + 1] = M[j];
}
if (lane_id < Q) {
ss[lane_id*T + 0] = S;
ss[lane_id*T + 1] = M[lane_id];
}
}
@ -10939,6 +10939,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
// increase shared memory limit to 96KB
//const size_t shmem_max = 96*1024;
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
switch (Q->ne[0]) {
case 64:
flash_attn_ext_f16<64, NQPB, NCPW>
@ -11045,6 +11049,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
default:
break;
}
CUDA_CHECK(cudaGetLastError());
}
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {