From a710d58d8867c0b494023b799bfd41573f6a439d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 21 Mar 2024 20:18:50 +0200 Subject: [PATCH] Try fix quantized k-cache on ROCm --- ggml-cuda.cu | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 04c6f5d07..f903c5f07 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6684,7 +6684,7 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - dsti->d = d; + dsti->d = __float2half(d); for (int j = 0; j < QK8_0; ++j) { const float x0 = xi[j]*id; @@ -6711,7 +6711,7 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { const float d = vmax / -8; const float id = d ? 1.0f/d : 0.0f; - dsti->d = d; + dsti->d = __float2half(d); for (int j = 0; j < QK4_0/2; ++j) { const float x0 = xi[0 + j]*id; @@ -6742,8 +6742,8 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { const float d = (vmax - vmin) / ((1 << 4) - 1); const float id = d ? 1.0f/d : 0.0f; - dsti->dm.x = d; - dsti->dm.y = vmin; + dsti->dm.x = __float2half(d); + dsti->dm.y = __float2half(vmin); for (int j = 0; j < QK4_1/2; ++j) { const float x0 = (xi[0 + j] - vmin)*id; @@ -6775,7 +6775,7 @@ static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { const float d = vmax / -16; const float id = d ? 1.0f/d : 0.0f; - dsti->d = d; + dsti->d = __float2half(d); uint32_t qh = 0; for (int j = 0; j < QK5_0/2; ++j) { @@ -6808,8 +6808,8 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { const float d = (max - min) / 31; const float id = d ? 1.0f/d : 0.0f; - dsti->dm.x = d; - dsti->dm.y = min; + dsti->dm.x = __float2half(d); + dsti->dm.y = __float2half(min); uint32_t qh = 0; for (int j = 0; j < QK5_1/2; ++j) { @@ -6870,7 +6870,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { sumq2 += w0*v0*v0 + w1*v1*v1; } - dsti->d = sumq2 > 0 ? sumqx/sumq2 : d; + dsti->d = __float2half(sumq2 > 0 ? sumqx/sumq2 : d); }