From 1ca802a3e0314b977b4e53ad0dd729fc2f9f487c Mon Sep 17 00:00:00 2001
From: slaren
Date: Tue, 28 May 2024 01:19:36 +0200
Subject: [PATCH] parallelize fattn compilation test

---
 Makefile | 2 +-
 ggml-cuda/fattn-common.cuh | 378 ++++++++++++++++++++++++++-
 ggml-cuda/fattn-vec-f16-f16.cu | 7 +
 ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu | 5 +
 ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu | 5 +
 ggml-cuda/fattn-vec-f16.cu | 281 --------------------
 6 files changed, 390 insertions(+), 288 deletions(-)
 create mode 100644 ggml-cuda/fattn-vec-f16-f16.cu
 create mode 100644 ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu
 create mode 100644 ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu

diff --git a/Makefile b/Makefile
index 94cfdab4e..278a06d2f 100644
--- a/Makefile
+++ b/Makefile
@@ -508,7 +508,7 @@ define NVCC_COMPILE
 endef # NVCC_COMPILE
 endif # JETSON_EOL_MODULE_DETECT

-ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
+ggml-cuda/%.o: ggml-cuda/%.cu ggml.h ggml-common.h ggml-cuda/common.cuh
 	$(NVCC_COMPILE)

 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh
index b9703dac2..ab840be4d 100644
--- a/ggml-cuda/fattn-common.cuh
+++ b/ggml-cuda/fattn-common.cuh
@@ -49,7 +49,7 @@ typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);

 template
-static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
+__device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 #if __CUDA_ARCH__ > MIN_CC_DP4A

@@ -263,7 +263,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
 }

 template
-static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
+__device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 #if __CUDA_ARCH__ > MIN_CC_DP4A

@@ -304,7 +304,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
 }

 template
-static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
+__device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {

     const half2 * K_h2 = (const half2 *) K_c;
@@ -393,7 +393,7 @@ typedef half (*dequantize_1_f16_t)(const void *, const int64_t);
 typedef float (*dequantize_1_f32_t)(const void *, const int64_t);

 template
-static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
+__device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
     const block_q4_0 * x = (const block_q4_0 *) vx;

     const int64_t ib = i / QK4_0;
@@ -485,7 +485,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
 }

 template
-static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
+__device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
     const block_q8_0 * x = (const block_q8_0 *) vx;

     const int64_t ib = i / QK8_0;
@@ -504,7 +504,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
 }

 template
-static __device__ __forceinline__ T dequantize_1_f16(const void *
__restrict__ vx, const int64_t i) { +__device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { const half * x = (const half *) vx; return x[i]; @@ -669,3 +669,369 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); CUDA_CHECK(cudaGetLastError()); } + +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + Q += nb02* blockIdx.y + nb01*ic0; + K += nb12*(blockIdx.y / gqa_ratio); + V += nb22*(blockIdx.y / gqa_ratio); + + const half * maskh = (const half *) mask + ne11*ic0; + + const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const half slopeh = __float2half(slopef); + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = D / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half KQ[ncols*D]; + half2 * KQ2 = (half2 *) KQ; + + half kqmax[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax[j] = -HALF_MAX_HALF; + } + half kqsum[ncols] = {0.0f}; + + __shared__ half kqmax_shared[ncols][WARP_SIZE]; + __shared__ half kqsum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF; + kqsum_shared[j][threadIdx.x] = 0.0f; + } + } + __syncthreads(); + + // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: + half2 Q_h2[ncols][D/(2*WARP_SIZE)]; + int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; + half2 Q_ds[ncols][D/QK8_1 == 0 ? 
1 : D/QK8_1]; + if (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 2 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tmp_q_i32[i] = 0; + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); + } + continue; + } + + const float * Q_f = (const float *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; + Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + } + } + } + + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ[j*D + tid] = -HALF_MAX_HALF; + } + + half2 VKQ[ncols] = {{0.0f, 0.0f}}; + + const int k_start = parallel_blocks == 1 ? 0 : ip*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + // Calculate KQ tile and keep track of new maximum KQ values: + + // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, + // see https://github.com/ggerganov/llama.cpp/pull/7061 . + // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). + half kqmax_new = kqmax[0]; + half kqmax_new_arr[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax_new_arr[j] = kqmax[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + + if (ncols == 1) { + kqmax_new = ggml_cuda_hmax(kqmax_new, sum); + } else { + kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum); + } + + if (threadIdx.x == 0) { + KQ[j*D + i_KQ] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = ncols == 1 ? 
kqmax_new : kqmax_new_arr[j]; + + kqmax_new_j = warp_reduce_max(kqmax_new_j); + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = kqmax_new_j; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const half val = hexp(KQ[j*D + tid] - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale + val; + KQ[j*D + tid] = val; + + VKQ[j] *= __half2half2(KQ_max_scale); + } + + __syncthreads(); + +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } + + half2 V_k; + reinterpret_cast(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid); + reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqsum[j] = warp_reduce_sum(kqsum[j]); + if (threadIdx.x == 0) { + kqsum_shared[j][threadIdx.y] = kqsum[j]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ncols > 2 && ic0 + j_VKQ >= ne01) { + break; + } + + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; + kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); + + half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); + if (parallel_blocks == 1) { + dst_val /= kqsum[j_VKQ]; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + } + + if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { + dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); + } +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE +} + +#define DECL_FATTN_VEC_F16_INST(D, ncols, parallel_blocks, vec_dot_KQ, Q_q8_1, dequantize_1_v) \ + template __global__ void flash_attn_vec_ext_f16( \ + const char * __restrict__ Q, \ + const char * __restrict__ K, \ + const char * __restrict__ V, \ + const char * __restrict__ mask, \ + float * __restrict__ dst, \ + float2 * __restrict__ dst_meta, \ + const float scale, \ + const float max_bias, \ + const float m0, \ + const float m1, \ + const uint32_t n_head_log2, \ + const int ne00, \ + const int ne01, \ + const int ne02, \ + const int ne03, \ + const int ne10, \ + const int ne11, \ + const int ne12, \ + const int ne13, \ + const int ne31, \ + const int nb31, \ + const int nb01, \ + const int nb02, \ + const int nb03, \ + const int nb11, \ + const int nb12, \ + const int nb13, \ + const int nb21, \ + const int nb22, \ + const int nb23, \ + const int ne0, \ + const int ne1, \ + const int ne2, \ + const int ne3) + + +extern DECL_FATTN_VEC_F16_INST(64, 1, 4, (vec_dot_fattn_vec_KQ_f16), false, dequantize_1_f16); +extern DECL_FATTN_VEC_F16_INST(128, 1, 4, (vec_dot_fattn_vec_KQ_f16), false, dequantize_1_f16); +extern DECL_FATTN_VEC_F16_INST(256, 1, 4, (vec_dot_fattn_vec_KQ_f16), false, dequantize_1_f16); + +#define DECL_FATTN_VEC_INST(type_VKQ, D, cols_per_block, parallel_blocks, type_suffix_K, type_suffix_V) \ + template __global__ void flash_attn_vec_ext_##type_VKQ< \ + (D), cols_per_block, parallel_blocks, \ + vec_dot_fattn_vec_KQ_##type_suffix_K, ggml_type_##type_suffix_K != GGML_TYPE_F16, dequantize_1_##type_suffix_V>( \ + const char * __restrict__ Q, \ + const char * 
__restrict__ K, \ + const char * __restrict__ V, \ + const char * __restrict__ mask, \ + float * __restrict__ dst, \ + float2 * __restrict__ dst_meta, \ + const float scale, \ + const float max_bias, \ + const float m0, \ + const float m1, \ + const uint32_t n_head_log2, \ + const int ne00, \ + const int ne01, \ + const int ne02, \ + const int ne03, \ + const int ne10, \ + const int ne11, \ + const int ne12, \ + const int ne13, \ + const int ne31, \ + const int nb31, \ + const int nb01, \ + const int nb02, \ + const int nb03, \ + const int nb11, \ + const int nb12, \ + const int nb13, \ + const int nb21, \ + const int nb22, \ + const int nb23, \ + const int ne0, \ + const int ne1, \ + const int ne2, \ + const int ne3) + + +extern DECL_FATTN_VEC_INST(f16, 128, 1, 4, q4_0, q4_0); +extern DECL_FATTN_VEC_INST(f16, 128, 1, 4, q8_0, q8_0); diff --git a/ggml-cuda/fattn-vec-f16-f16.cu b/ggml-cuda/fattn-vec-f16-f16.cu new file mode 100644 index 000000000..557b561fb --- /dev/null +++ b/ggml-cuda/fattn-vec-f16-f16.cu @@ -0,0 +1,7 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_INST(64, 1, 4, (vec_dot_fattn_vec_KQ_f16), false, dequantize_1_f16); +DECL_FATTN_VEC_F16_INST(128, 1, 4, (vec_dot_fattn_vec_KQ_f16), false, dequantize_1_f16); +DECL_FATTN_VEC_F16_INST(256, 1, 4, (vec_dot_fattn_vec_KQ_f16), false, dequantize_1_f16); diff --git a/ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu b/ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu new file mode 100644 index 000000000..ea1b8ba2b --- /dev/null +++ b/ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu @@ -0,0 +1,5 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-vec-f16.cuh" + +DECL_FATTN_VEC_INST(f16, 128, 1, 4, q4_0, q4_0); diff --git a/ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu b/ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu new file mode 100644 index 000000000..773575746 --- /dev/null +++ b/ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-vec-f16.cuh" + +DECL_FATTN_VEC_INST(f16, 128, 1, 4, q8_0, q8_0); diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index c427e18ab..eac498c2d 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -2,287 +2,6 @@ #include "fattn-common.cuh" #include "fattn-vec-f16.cuh" -template // D == head size -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void flash_attn_vec_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { -#if FP16_AVAILABLE - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. 
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. - - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb02* blockIdx.y + nb01*ic0; - K += nb12*(blockIdx.y / gqa_ratio); - V += nb22*(blockIdx.y / gqa_ratio); - - const half * maskh = (const half *) mask + ne11*ic0; - - const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); - - __shared__ half KQ[ncols*D]; - half2 * KQ2 = (half2 *) KQ; - - half kqmax[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax[j] = -HALF_MAX_HALF; - } - half kqsum[ncols] = {0.0f}; - - __shared__ half kqmax_shared[ncols][WARP_SIZE]; - __shared__ half kqsum_shared[ncols][WARP_SIZE]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.y == 0) { - kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF; - kqsum_shared[j][threadIdx.x] = 0.0f; - } - } - __syncthreads(); - - // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: - half2 Q_h2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; - half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; - if (Q_q8_1) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > ncols && j >= ncols) { - break; - } - - // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); - - // Set memory to zero if out of bounds: - if (ncols > 2 && ic0 + j >= ne01) { -#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - tmp_q_i32[i] = 0; - } - if (threadIdx.x < D/QK8_1) { - tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); - } - continue; - } - - const float * Q_f = (const float *) (Q + j*nb01); -#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { - quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); - -#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; - Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; - } - } - - __syncthreads(); - } else { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); - Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } - } - } - - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -HALF_MAX_HALF; - } - - half2 VKQ[ncols] = {{0.0f, 0.0f}}; - - const int k_start = parallel_blocks == 1 ? 
0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { - // Calculate KQ tile and keep track of new maximum KQ values: - - // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, - // see https://github.com/ggerganov/llama.cpp/pull/7061 . - // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). - half kqmax_new = kqmax[0]; - half kqmax_new_arr[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax_new_arr[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { - break; - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); - sum = warp_reduce_sum(sum); - sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); - - if (ncols == 1) { - kqmax_new = ggml_cuda_hmax(kqmax_new, sum); - } else { - kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum); - } - - if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; - } - } - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; - - kqmax_new_j = warp_reduce_max(kqmax_new_j); - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = kqmax_new_j; - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const half val = hexp(KQ[j*D + tid] - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; - - VKQ[j] *= __half2half2(KQ_max_scale); - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } - - half2 V_k; - reinterpret_cast(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid); - reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); -#pragma unroll - for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; - } - } - - __syncthreads(); - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqsum[j] = warp_reduce_sum(kqsum[j]); - if (threadIdx.x == 0) { - kqsum_shared[j][threadIdx.y] = kqsum[j]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 2 && ic0 + j_VKQ >= ne01) { - break; - } - - kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; - kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); - - half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); - if (parallel_blocks == 1) { - dst_val /= kqsum[j_VKQ]; - } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; - } - - if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); - } -#else - NO_DEVICE_CODE; -#endif // FP16_AVAILABLE -} - void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_tensor * KQV = dst; ggml_tensor * Q = dst->src[0];
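
The build-level idea behind the patch: the flash_attn_vec_ext_f16 template and the helpers it needs move into fattn-common.cuh, the helpers lose their `static` so their instantiations get external linkage, the header declares the required specializations `extern`, and each K/V type combination gets its own tiny .cu file holding one explicit instantiation. Together with the relaxed Makefile pattern rule (ggml-cuda/%.o no longer requires a matching %.cuh), `make -j` can compile the per-combination files concurrently instead of rebuilding one monolithic fattn-vec-f16.cu. A reduced sketch of the same pattern; my_vec_kernel, fattn-sketch.cuh and fattn-sketch-d128.cu are hypothetical names, not from the patch:

    // fattn-sketch.cuh -- template definition plus extern instantiation declarations
    #include <cuda_fp16.h>

    template <int D> // D = head size, fixed at compile time
    __global__ void my_vec_kernel(const half * K, const half * Q, float * dst) {
        // body is visible to every includer, but the extern declarations below
        // stop each translation unit from implicitly instantiating it
    }

    extern template __global__ void my_vec_kernel< 64>(const half *, const half *, float *);
    extern template __global__ void my_vec_kernel<128>(const half *, const half *, float *);

    // fattn-sketch-d128.cu -- one small file per specialization; make -j builds them in parallel
    // #include "fattn-sketch.cuh"
    template __global__ void my_vec_kernel<128>(const half *, const half *, float *);

The three new fattn-vec-f16-*.cu files are exactly this shape: a handful of includes plus the DECL_* instantiation lines, so supporting another K/V quantization pair means adding another tiny translation unit rather than growing a single slow-to-compile one.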
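In the kernel itself, blockIdx.y selects the Q head and the matching K/V head is derived from gqa_ratio = ne02 / ne12, so with grouped-query attention several Q heads walk the same K and V data (K += nb12*(blockIdx.y / gqa_ratio), likewise for V). A standalone illustration of that index mapping; the head counts are made up for the example:

    // Grouped-query attention head mapping, as in the pointer setup at the top of
    // flash_attn_vec_ext_f16. Hypothetical sizes, host-side illustration only.
    #include <cstdio>

    int main() {
        const int n_head_q  = 32;                   // ne02: number of Q heads
        const int n_head_kv = 8;                    // ne12: number of K/V heads
        const int gqa_ratio = n_head_q / n_head_kv;

        for (int q_head = 0; q_head < n_head_q; ++q_head) {
            const int kv_head = q_head / gqa_ratio; // K += nb12*kv_head; V += nb22*kv_head;
            std::printf("Q head %2d -> K/V head %d\n", q_head, kv_head);
        }
        return 0;
    }

With gqa_ratio == 1 (no grouped-query attention) the mapping degenerates to kv_head == q_head, which is why the same kernel covers both cases.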
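The kernel never materializes the full softmax over the KV sequence: per Q column it keeps a running maximum kqmax, a running exponential sum kqsum and an unnormalized accumulator VKQ, and whenever a KV tile produces a larger score it rescales the old sums by hexp(kqmax - kqmax_new). A scalar float reference of that update rule; softmax_weighted_sum is a hypothetical helper written only to show the arithmetic, one value per KV position, no half2 packing and no warp reductions:

    #include <cmath>
    #include <vector>

    // Running ("online") softmax-weighted sum, the scalar analogue of the
    // kqmax/kqsum/VKQ bookkeeping in flash_attn_vec_ext_f16.
    static float softmax_weighted_sum(const std::vector<float> & scores, const std::vector<float> & values) {
        float kqmax = -INFINITY; // the kernel starts from -HALF_MAX_HALF instead
        float kqsum = 0.0f;      // running sum of exp(score - kqmax)
        float vkq   = 0.0f;      // running, still unnormalized, value accumulator

        for (size_t k = 0; k < scores.size(); ++k) {
            const float kqmax_new = fmaxf(kqmax, scores[k]);
            const float scale     = expf(kqmax - kqmax_new); // KQ_max_scale in the kernel

            kqsum = kqsum*scale + expf(scores[k] - kqmax_new);
            vkq   = vkq  *scale + expf(scores[k] - kqmax_new)*values[k];
            kqmax = kqmax_new;
        }
        return vkq / kqsum; // the kernel divides by kqsum itself only when parallel_blocks == 1
    }

Keeping everything relative to the running maximum is what keeps the per-position exponentials representable in half precision.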
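When parallel_blocks > 1 the KV sequence is additionally split across blocks along blockIdx.x: each block processes the tiles starting at ip*D, ip*D + parallel_blocks*D, ..., writes its unnormalized accumulator into dst and its (kqmax, kqsum) pair into dst_meta, and the launch_fattn() epilogue visible at the top of the large fattn-common.cuh hunk launches a follow-up pass over dst_tmp/dst_tmp_meta to merge the chunks. A host-side sketch of that merge; the Partial struct and combine_partials are my own names for the illustration:

    #include <cmath>
    #include <vector>

    struct Partial {
        float max; // dst_meta[...].x: the chunk's running KQ maximum
        float sum; // dst_meta[...].y: the chunk's running exp-sum
        float vkq; // the chunk's unnormalized accumulator written to dst
    };

    // Merge the per-chunk partial results into one normalized output element.
    static float combine_partials(const std::vector<Partial> & parts) {
        float max_all = -INFINITY;
        for (const Partial & p : parts) {
            max_all = fmaxf(max_all, p.max);
        }

        float sum_all = 0.0f;
        float vkq_all = 0.0f;
        for (const Partial & p : parts) {
            const float scale = expf(p.max - max_all); // bring the chunk onto the global maximum
            sum_all += p.sum*scale;
            vkq_all += p.vkq*scale;
        }
        return vkq_all / sum_all;
    }

Each chunk only knows its local maximum, so the merge has to rescale every partial onto the global maximum before the sums and accumulators can be added.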