diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index 89bbc9749..b9703dac2 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -404,7 +404,13 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ const int q0 = x[ib].qs[iqs]; const int q = ((q0 >> (4*shift)) & 0x0F) - 8; - return d*((T) q); +#if FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); } template @@ -444,7 +450,13 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ const int qh = ((qh0 >> idq) << 4) & 0x10; const int q = (ql | qh) - 16; - return d*((T) q); +#if FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); } template @@ -482,7 +494,13 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ const T d = x[ib].d; const int q = x[ib].qs[iqs]; - return d*((T) q); +#if FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); } template