cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels

This commit is contained in:
Georgi Gerganov 2023-12-05 16:47:34 +02:00
parent e8457c90a0
commit b2acedeb1a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -7,6 +7,7 @@
#include <stdio.h> #include <stdio.h>
#include <atomic> #include <atomic>
#include <assert.h> #include <assert.h>
#include <float.h>
#if defined(GGML_USE_HIPBLAS) #if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
@ -4587,20 +4588,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
block_q4_0 * dsti = (block_q4_0 *) cdsti; block_q4_0 * dsti = (block_q4_0 *) cdsti;
float amax = 0.0f; float amax = 0.0f;
float max = 0.0f; float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) { for (int j = 0; j < QK4_0; ++j) {
const float v = xi[j]; const float v = xi[j];
if (amax < fabsf(v)) { if (amax < fabsf(v)) {
amax = fabsf(v); amax = fabsf(v);
max = v; vmax = v;
} }
} }
const float d = max / -8; const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
y[i].d = d; dsti->d = d;
for (int j = 0; j < QK4_0/2; ++j) { for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = xi[0 + j]*id; const float x0 = xi[0 + j]*id;
@ -4614,6 +4615,38 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
} }
} }
// Quantize QK4_1 consecutive floats into one block_q4_1.
//
//   cxi   - source bytes, read as an array of QK4_1 floats
//   cdsti - destination bytes, written as a single block_q4_1
//
// The destination receives the step size in dm.x, the smallest input value in
// dm.y, and one 4-bit code per input value packed two per byte in qs.
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q4_1  * dst = (block_q4_1 *) cdsti;

    // scan the block for its smallest and largest value
    float range_lo = FLT_MAX;
    float range_hi = -FLT_MAX;

    for (int k = 0; k < QK4_1; ++k) {
        const float val = src[k];

        range_lo = val < range_lo ? val : range_lo;
        range_hi = val > range_hi ? val : range_hi;
    }

    // 4-bit codes run from 0 to 15, so the value range is split into 15 steps
    const int   nsteps = (1 << 4) - 1;
    const float step   = (range_hi - range_lo) / nsteps;

    // when every value is equal the step is zero; use 0 as the reciprocal
    // instead of dividing by zero, which maps every value to code 0
    const float inv_step = step != 0.0f ? 1.0f/step : 0.0f;

    dst->dm.x = step;
    dst->dm.y = range_lo;

    const int half_blk = QK4_1/2;

    for (int k = 0; k < half_blk; ++k) {
        // element k comes from the first half of the block, its partner from the second half
        const float scaled0 = (src[k]            - range_lo)*inv_step;
        const float scaled1 = (src[half_blk + k] - range_lo)*inv_step;

        // adding 0.5 before the truncating cast rounds to nearest; clamp to the top code
        const uint8_t code0 = min(15, (int8_t)(scaled0 + 0.5f));
        const uint8_t code1 = min(15, (int8_t)(scaled1 + 0.5f));

        // first-half code in the low nibble, second-half code in the high nibble
        dst->qs[k] = code0 | (code1 << 4);
    }
}
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,