mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 06:10:29 +01:00
48 lines
1.8 KiB
Plaintext
48 lines
1.8 KiB
Plaintext
|
// Stride constant shared by the flash-attention kernels. It is not referenced in this section;
// presumably it is the KQ / KV tile length processed per step. TODO confirm against its users.
#define FATTN_KQ_STRIDE 256

// Half of the largest finite half-precision value (65504). Use the negation of this instead of
// -INFINITY to initialize KQ max values: subtracting two -INFINITY maxima would yield NaN.
// Halved, presumably so that differences of such values stay finite in half precision.
#define HALF_MAX_HALF __float2half(65504.0f/2)

// Softmax rescale factors exp(diff) with diff <= this threshold are flushed to zero
// (only diff > SOFTMAX_FTZ_THRESHOLD keeps exp(diff)); the original intent is to avoid NaNs.
// Parenthesized so the negative literal expands safely in any expression context.
#define SOFTMAX_FTZ_THRESHOLD (-20.0f)
|
||
|
|
||
|
// Merges `parallel_blocks` partial flash-attention results for the same output row into one
// normalized row of D values, rescaling each partial by exp(its KQ max - global KQ max).
//
// Template parameters:
//   D               - head size; the kernel expects blockDim.x == D (one thread per output element,
//                     see __launch_bounds__ and __builtin_assume below).
//   parallel_blocks - number of partial results merged per output row.
//
// Launch layout (implied by the indexing below):
//   blockIdx.x - selects a group of gridDim.y rows; all three pointers are offset by it.
//   blockIdx.y - selects one row (of D values) inside that group.
//
// Buffers, relative to the blockIdx.x group:
//   VKQ_parts - partial l of row y occupies [(l*gridDim.y + y)*D, (l*gridDim.y + y + 1)*D).
//   VKQ_meta  - entry for row y, partial l is at index y*parallel_blocks + l; .x is used as that
//               partial's KQ maximum and .y as its softmax denominator (row sum).
//   dst       - row y is written at [y*D, (y+1)*D).
template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float * __restrict__ dst) {
    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
    dst       +=                 D * gridDim.y*blockIdx.x;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);

    // Stage this row's metadata in shared memory. The strided loop keeps the copy complete even
    // when parallel_blocks exceeds the D threads of the block, so meta[] is never read uninitialized.
    __shared__ float2 meta[parallel_blocks];
    for (int l = tid; l < parallel_blocks; l += D) {
        meta[l] = VKQ_meta[blockIdx.y*parallel_blocks + l];
    }

    __syncthreads();

    // Global KQ maximum over all partials of this row.
    float kqmax = meta[0].x;
#pragma unroll
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = max(kqmax, meta[l].x);
    }

    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
#pragma unroll
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff = meta[l].x - kqmax;
        // Flush the rescale factor to +0.0f when diff <= SOFTMAX_FTZ_THRESHOLD (NaN diffs compare
        // false and are flushed too). A select is used because masking the bits of a const float
        // through a uint32_t pointer is undefined behavior the compiler may silently ignore.
        const float KQ_max_scale = diff > SOFTMAX_FTZ_THRESHOLD ? expf(diff) : 0.0f;

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
}
|