Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-02-05 16:10:42 +01:00)
HIP: fix flash_attn_stream_k_fixup warning (#11604)
This commit is contained in:
parent: 396856b400
commit: 6eecde3cc8
@@ -516,6 +516,12 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+
 template<int D, int ncols, int KQ_stride> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
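For readers unfamiliar with the pattern: clang, which the HIP toolchain is built on, emits -Wpass-failed when an optimization requested via #pragma unroll cannot be applied, here because the jt*ncols + j >= ne01 early exit depends on a runtime value. The fix brackets only the affected kernel with a push/ignored/pop of the diagnostic state. Below is a minimal sketch of the same technique; the kernel scale_rows, its signature, and its loop body are hypothetical illustrations, not code from this commit.

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__

// Hypothetical kernel: the early return depends on the runtime argument ne01,
// so clang may be unable to fully unroll the loop and, without the pragma
// above, would report -Wpass-failed.
template <int ncols>
static __global__ void scale_rows(float * dst, const float * src, const int ne01) {
    const int row = blockIdx.x;
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        if (row*ncols + j >= ne01) {
            return; // runtime-dependent bound: this is what defeats the unroller
        }
        dst[row*ncols + j] = 2.0f*src[row*ncols + j];
    }
}

#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__

The push/pop pair matters: ignoring -Wpass-failed globally would also hide genuine optimization failures elsewhere in the file.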
@@ -614,6 +620,10 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 }
 
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
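The context lines also show why this warning is HIP-specific territory: the __launch_bounds__(D, 1) occupancy hint is compiled out on AMD via a platform guard, presumably because it is tuned for NVIDIA hardware. A sketch of that guard pattern follows; fill_rows and its body are illustrative only, not from the commit.

template <int D> // D == threads per block
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1) // promise at most D threads per block so the compiler can budget registers
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void fill_rows(float * dst, const float value) {
    // Illustrative body: each thread writes one element of its block's row.
    dst[blockIdx.x*D + threadIdx.x] = value;
}

The remaining two hunks make a purely cosmetic change in the softmax kernels (presumably ggml/src/ggml-cuda/softmax.cu, where soft_max_f32 lives): the bare #endif directives closing the same __clang__ guards gain a // __clang__ comment, matching the style introduced above.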
@@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #ifdef __clang__
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wpass-failed"
-#endif
+#endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
         const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@@ -126,7 +126,7 @@ static __global__ void soft_max_f32(
 }
 #ifdef __clang__
 #pragma clang diagnostic pop
-#endif
+#endif // __clang__
 
 static __global__ void soft_max_back_f32(
         const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {