From b2acedeb1a2f7440426d50bc4c01b5a3ea82bd76 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 16:47:34 +0200 Subject: [PATCH] cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels --- ggml-cuda.cu | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c2ce1769c..53e53a0d1 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7,6 +7,7 @@ #include #include #include +#include #if defined(GGML_USE_HIPBLAS) #include @@ -4587,20 +4588,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { block_q4_0 * dsti = (block_q4_0 *) cdsti; float amax = 0.0f; - float max = 0.0f; + float vmax = 0.0f; for (int j = 0; j < QK4_0; ++j) { const float v = xi[j]; if (amax < fabsf(v)) { amax = fabsf(v); - max = v; + vmax = v; } } - const float d = max / -8; + const float d = vmax / -8; const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + dsti->d = d; for (int j = 0; j < QK4_0/2; ++j) { const float x0 = xi[0 + j]*id; @@ -4614,6 +4615,38 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { } } +static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_1 * dsti = (block_q4_1 *) cdsti; + + float vmin = FLT_MAX; + float vmax = -FLT_MAX; + + for (int j = 0; j < QK4_1; ++j) { + const float v = xi[j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dsti->dm.x = d; + dsti->dm.y = vmin; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (xi[0 + j] - vmin)*id; + const float x1 = (xi[QK4_1/2 + j] - vmin)*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + template static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,