HIP: fix flash_attn_stream_k_fixup warning (#11604)

Johannes Gäßler, 2025-02-02 23:48:29 +01:00 (committed by GitHub)
parent 396856b400
commit 6eecde3cc8
2 changed files with 12 additions and 2 deletions


@@ -516,6 +516,12 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+
 template<int D, int ncols, int KQ_stride> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -614,6 +620,10 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 }
 
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)


@@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #ifdef __clang__
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wpass-failed"
-#endif
+#endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
         const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@@ -126,7 +126,7 @@ static __global__ void soft_max_f32(
 }
 #ifdef __clang__
 #pragma clang diagnostic pop
-#endif
+#endif // __clang__
 
 static __global__ void soft_max_back_f32(
         const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
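
For reference, a minimal standalone sketch of the pattern both files now share; the kernel below (example_copy_rows) and its parameters are hypothetical and are not part of this commit. Per the comment added in the first hunk, the HIP compiler complains that it cannot unroll a loop because of a data-dependent conditional; the resulting -Wpass-failed diagnostic is silenced only around the affected kernel by the clang push/ignored/pop pair.

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__

// Hypothetical kernel mirroring the loop shape described by the commit comment:
// the jt*ncols + j >= ne01 check keeps the compiler from fully unrolling the loop,
// which is what triggers -Wpass-failed under "#pragma unroll".
template<int ncols>
static __global__ void example_copy_rows(const float * x, float * dst, const int ne01) {
    const int jt = blockIdx.x;
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        if (jt*ncols + j >= ne01) {
            return; // data-dependent early exit: the trip count is not provable at compile time
        }
        dst[jt*ncols + j] = x[jt*ncols + j];
    }
}

#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__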