cuda : minor

Georgi Gerganov 2024-02-04 09:57:58 +02:00
parent ef68fac2a8
commit 1846e92a90


@@ -6399,10 +6399,10 @@ static __global__ void flash_attn_f32(
 }
 
 #if __CUDA_ARCH__ >= CC_VOLTA
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
 #endif
 
 // based on metal version
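The typedefs above name the nvcuda::wmma fragments for 16x16x16 half-precision tensor-core tiles. As a standalone illustration (not part of the commit, simplified to a single tile, kernel name is a placeholder), this is the usual way such fragments are driven: fill the accumulator, load an A and a B tile, multiply-accumulate, and store the result.

    // Minimal sketch, assuming 16x16 row-major A, B, C and a single-warp launch,
    // e.g. wmma_tile_mul<<<1, 32>>>(A, B, C).
    #include <cuda_fp16.h>
    #include <mma.h>

    static __global__ void wmma_tile_mul(const half * A, const half * B, half * C) {
    #if __CUDA_ARCH__ >= 700
        nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> a;
        nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> b;
        nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> acc;

        nvcuda::wmma::fill_fragment(acc, 0.0f);     // acc = 0
        nvcuda::wmma::load_matrix_sync(a, A, 16);   // load one 16x16 tile of A (leading dim 16)
        nvcuda::wmma::load_matrix_sync(b, B, 16);   // load one 16x16 tile of B
        nvcuda::wmma::mma_sync(acc, a, b, acc);     // acc += A*B on the tensor cores
        nvcuda::wmma::store_matrix_sync(C, acc, 16, nvcuda::wmma::mem_row_major);
    #endif
    }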
@@ -6443,15 +6443,17 @@ static __global__ void flash_attn_ext_f16(
     const int iq2 = blockIdx.y;
     const int iq1 = blockIdx.x * Q;
 
+    const int D2 = D/2;
     const int D16 = D/16;
     const int Q16 = Q/16;
+    const int C16 = C/16;
     const int NW = WARP_SIZE;
     const int SH = (C + Q); // shared memory per simdgroup in (half)
 
     const int T = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;             // shared memory size per query in (half2)
     const int C2 = C/2;
-    const int D2 = D/2;
 
     extern __shared__ half __flash_attn_f16_shmem[];
 
     // pq
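For reference, the constants above determine how much dynamic shared memory the host has to request: SH half elements per simdgroup, T half elements per query, times Q queries per block. A small host-side sketch of that arithmetic (the function name and the example values in the comment are assumptions, not taken from the commit):

    // Sketch only: mirrors the host-side formula shown further down,
    // nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2).
    static size_t flash_attn_shmem_bytes(int D, int Q, int C, int num_warps) {
        const int SH = C + Q;                       // shared memory per simdgroup, in half elements
        const int T  = D + num_warps*SH;            // shared memory per query, in half elements
        return (size_t) Q * T * (sizeof(float)/2);  // e.g. D=128, Q=8, C=32, num_warps=4 -> 4608 bytes
    }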
@@ -6571,7 +6573,7 @@ static __global__ void flash_attn_ext_f16(
         // Q*K^T
         {
 #pragma unroll
-            for (int cc = 0; cc < C/16; ++cc) {
+            for (int cc = 0; cc < C16; ++cc) {
                 half16x16_acc mqk[Q16];
                 for (int j = 0; j < Q16; ++j) {
                     nvcuda::wmma::fill_fragment(mqk[j], 0);
@@ -6684,7 +6686,7 @@ static __global__ void flash_attn_ext_f16(
         // O = O + (Q*K^T)*V
         {
-            for (int cc = 0; cc < C/16; ++cc) {
+            for (int cc = 0; cc < C16; ++cc) {
                 const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
 
                 half16x16_b mv[D16];
@@ -6707,11 +6709,9 @@ static __global__ void flash_attn_ext_f16(
         }
 
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (int j = 0; j < Q; ++j) {
-            if (lane_id == j) {
-                ss[j*T + 0] = S;
-                ss[j*T + 1] = M[j];
-            }
+        if (lane_id < Q) {
+            ss[lane_id*T + 0] = S;
+            ss[lane_id*T + 1] = M[lane_id];
         }
     }
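The rewrite above replaces a loop in which only the matching lane wrote each row with a direct lane-predicated write. Both forms store the same values, since every lane of a warp has a unique lane_id. A standalone sketch of the two variants (the helper names are placeholders, not the kernel's actual locals):

    #include <cuda_fp16.h>

    // old form: iterate over rows, only the lane whose id equals j writes row j
    static __device__ void store_stats_loop(half * ss, int T, int Q, int lane_id, half S, const half * M) {
        for (int j = 0; j < Q; ++j) {
            if (lane_id == j) {
                ss[j*T + 0] = S;
                ss[j*T + 1] = M[j];
            }
        }
    }

    // new form: the first Q lanes write their own row directly -- same stores, no loop
    static __device__ void store_stats_direct(half * ss, int T, int Q, int lane_id, half S, const half * M) {
        if (lane_id < Q) {
            ss[lane_id*T + 0] = S;
            ss[lane_id*T + 1] = M[lane_id];
        }
    }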
@@ -10939,6 +10939,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
 
+    // increase shared memory limit to 96KB
+    //const size_t shmem_max = 96*1024;
+    //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
+
     switch (Q->ne[0]) {
         case 64:
             flash_attn_ext_f16<64, NQPB, NCPW>
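The commented-out lines added above refer to raising the per-kernel dynamic shared memory cap, which defaults to 48 KB. A standalone sketch of that pattern (kernel and launch configuration are placeholders, not the ones in this file):

    #include <cuda_runtime.h>

    __global__ void big_shmem_kernel(float * dst) {   // placeholder kernel using dynamic shared memory
        extern __shared__ float buf[];
        buf[threadIdx.x] = (float) threadIdx.x;
        dst[threadIdx.x] = buf[threadIdx.x];
    }

    static cudaError_t launch_with_96kb(float * dst) {
        const size_t shmem_max = 96*1024;

        // opt this kernel in to more than 48 KB of dynamic shared memory (needs a GPU that supports it)
        cudaError_t err = cudaFuncSetAttribute(big_shmem_kernel,
                cudaFuncAttributeMaxDynamicSharedMemorySize, (int) shmem_max);
        if (err != cudaSuccess) {
            return err;
        }

        big_shmem_kernel<<<1, 256, shmem_max>>>(dst);
        return cudaGetLastError();                    // launch errors only show up here
    }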
@@ -11045,6 +11049,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
         default:
             break;
     }
+
+    CUDA_CHECK(cudaGetLastError());
 }
 
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
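The CUDA_CHECK(cudaGetLastError()) added after the switch matters because kernel launches are asynchronous and return no status of their own; launch failures (bad configuration, too much shared memory, and so on) only surface through cudaGetLastError(). A self-contained sketch with a generic checking macro (not the CUDA_CHECK macro from this file):

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    #define CHECK_CUDA(expr)                                               \
        do {                                                               \
            cudaError_t err_ = (expr);                                     \
            if (err_ != cudaSuccess) {                                     \
                fprintf(stderr, "%s:%d: %s\n", __FILE__, __LINE__,         \
                        cudaGetErrorString(err_));                         \
                exit(1);                                                   \
            }                                                              \
        } while (0)

    __global__ void noop_kernel() {}

    int main() {
        noop_kernel<<<1, 4096>>>();       // invalid: 4096 threads per block exceeds the limit
        CHECK_CUDA(cudaGetLastError());   // reports "invalid configuration argument"
        return 0;
    }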