mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
Add ability to use Q5_0, Q5_1, and IQ4_NL for quantized K cache (#6183)
* k_cache: be able to use Q5_0 * k_cache: be able to use Q5_1 on CUDA * k_cache: be able to use Q5_0 on Metal * k_cache: be able to use Q5_1 on Metal * k_cache: be able to use IQ4_NL - just CUDA for now * k_cache: be able to use IQ4_NL on Metal * k_cache: add newly added supported types to llama-bench and CUDA supports_op --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
c5b8595e3f
commit
76aa30a263
@ -1590,6 +1590,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
|
|||||||
if (s == "q4_1") {
|
if (s == "q4_1") {
|
||||||
return GGML_TYPE_Q4_1;
|
return GGML_TYPE_Q4_1;
|
||||||
}
|
}
|
||||||
|
if (s == "iq4_nl") {
|
||||||
|
return GGML_TYPE_IQ4_NL;
|
||||||
|
}
|
||||||
if (s == "q5_0") {
|
if (s == "q5_0") {
|
||||||
return GGML_TYPE_Q5_0;
|
return GGML_TYPE_Q5_0;
|
||||||
}
|
}
|
||||||
|
@ -249,6 +249,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
|
|||||||
if (s == "q5_1") {
|
if (s == "q5_1") {
|
||||||
return GGML_TYPE_Q5_1;
|
return GGML_TYPE_Q5_1;
|
||||||
}
|
}
|
||||||
|
if (s == "iq4_nl") {
|
||||||
|
return GGML_TYPE_IQ4_NL;
|
||||||
|
}
|
||||||
|
|
||||||
return GGML_TYPE_COUNT;
|
return GGML_TYPE_COUNT;
|
||||||
}
|
}
|
||||||
|
165
ggml-cuda.cu
165
ggml-cuda.cu
@ -6757,6 +6757,123 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Quantize QK5_0 consecutive floats into one block_q5_0.
//   cxi   - raw pointer to the float source values
//   cdsti - raw pointer to the destination block_q5_0
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q5_0  * blk = (block_q5_0 *) cdsti;

    // Locate the (signed) element with the largest magnitude in the block.
    float best_abs = 0.0f;
    float best_val = 0.0f;
    for (int k = 0; k < QK5_0; ++k) {
        const float cur = src[k];
        if (best_abs < fabsf(cur)) {
            best_abs = fabsf(cur);
            best_val = cur;
        }
    }

    // Scale chosen so that the largest-magnitude element maps to -16 before the +16.5 offset below.
    const float scale     = best_val / -16;
    const float inv_scale = scale ? 1.0f/scale : 0.0f; // all-zero block: avoid dividing by zero

    blk->d = scale;

    // Bit 4 of every 5-bit quant is collected here; the low 4 bits go into qs.
    uint32_t high_bits = 0;
    for (int k = 0; k < QK5_0/2; ++k) {
        const float scaled_lo = src[0       + k]*inv_scale;
        const float scaled_hi = src[QK5_0/2 + k]*inv_scale;

        // Offset, truncate, and clamp to the largest 5-bit value.
        const uint8_t q_lo = min(31, (int8_t)(scaled_lo + 16.5f));
        const uint8_t q_hi = min(31, (int8_t)(scaled_hi + 16.5f));

        // Element k occupies the low nibble, element k + QK5_0/2 the high nibble.
        blk->qs[k] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);
        high_bits |= ((q_lo & 0x10u) >> 4) << (k + 0);
        high_bits |= ((q_hi & 0x10u) >> 4) << (k + QK5_0/2);
    }
    memcpy(blk->qh, &high_bits, sizeof(high_bits));
}
|
||||||
|
|
||||||
|
// Quantize QK5_1 consecutive floats into one block_q5_1 (scale + minimum encoding).
//   cxi   - raw pointer to the float source values
//   cdsti - raw pointer to the destination block_q5_1
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q5_1  * blk = (block_q5_1 *) cdsti;

    // Range of the block.
    float lowest  = src[0];
    float highest = src[0];
    for (int k = 1; k < QK5_1; ++k) {
        const float cur = src[k];
        lowest  = cur < lowest  ? cur : lowest;
        highest = cur > highest ? cur : highest;
    }

    // 31 steps cover [lowest, highest].
    const float step     = (highest - lowest) / 31;
    const float inv_step = step ? 1.0f/step : 0.0f; // constant block: avoid dividing by zero

    blk->dm.x = step;
    blk->dm.y = lowest;

    // Bit 4 of every 5-bit quant is collected here; the low 4 bits go into qs.
    uint32_t high_bits = 0;
    for (int k = 0; k < QK5_1/2; ++k) {
        const float scaled_lo = (src[0       + k] - lowest)*inv_step;
        const float scaled_hi = (src[QK5_1/2 + k] - lowest)*inv_step;

        // Round to nearest by adding 0.5 and truncating.
        const uint8_t q_lo = (uint8_t)(scaled_lo + 0.5f);
        const uint8_t q_hi = (uint8_t)(scaled_hi + 0.5f);

        // Element k occupies the low nibble, element k + QK5_1/2 the high nibble.
        blk->qs[k] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);
        high_bits |= ((q_lo & 0x10u) >> 4) << (k + 0);
        high_bits |= ((q_hi & 0x10u) >> 4) << (k + QK5_1/2);
    }
    memcpy(blk->qh, &high_bits, sizeof(high_bits));
}
|
||||||
|
|
||||||
|
// Return the index of the entry of val[0..n-1] closest to x.
// The search below only works if val is sorted in ascending order (not checked here).
// Values at or beyond either end clamp to the first/last index; on an exact tie
// between two neighbours the upper index is returned.
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
    const int last = n-1;

    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[last]) {
        return last;
    }

    // Bisect until lo and hi are adjacent entries bracketing x.
    int lo = 0;
    int hi = last;
    while (hi-lo > 1) {
        const int mid = (lo+hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // Pick whichever neighbour is nearer to x.
    if (x - val[hi-1] < val[hi] - x) {
        return hi-1;
    }
    return hi;
}
|
||||||
|
|
||||||
|
// Quantize QK4_NL consecutive floats into one block_iq4_nl using the kvalues_iq4nl table.
//   cxi   - raw pointer to the float source values
//   cdsti - raw pointer to the destination block_iq4_nl
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
    const float  * src = (const float *) cxi;
    block_iq4_nl * blk = (block_iq4_nl *) cdsti;

    // Locate the (signed) element with the largest magnitude in the block.
    float best_abs = 0.0f;
    float best_val = 0.0f;
    for (int k = 0; k < QK4_NL; ++k) {
        const float cur = src[k];
        if (best_abs < fabsf(cur)) {
            best_abs = fabsf(cur);
            best_val = cur;
        }
    }

    // Initial scale: map the largest-magnitude element onto the first table entry.
    float scale = best_val / kvalues_iq4nl[0];
    const float inv_scale = scale ? 1.0f/scale : 0.0f; // all-zero block: avoid dividing by zero

    // Accumulators for the weighted least-squares refit of the scale (weights are x^2).
    float sum_qx = 0, sum_qq = 0;
    for (int k = 0; k < QK4_NL/2; ++k) {
        const float scaled_lo = src[0        + k]*inv_scale;
        const float scaled_hi = src[QK4_NL/2 + k]*inv_scale;

        const uint8_t idx_lo = best_index_int8(16, kvalues_iq4nl, scaled_lo);
        const uint8_t idx_hi = best_index_int8(16, kvalues_iq4nl, scaled_hi);

        // Element k occupies the low nibble, element k + QK4_NL/2 the high nibble.
        blk->qs[k] = idx_lo | (idx_hi << 4);

        const float tab_lo = kvalues_iq4nl[idx_lo];
        const float tab_hi = kvalues_iq4nl[idx_hi];
        const float wgt_lo = src[0        + k]*src[0        + k];
        const float wgt_hi = src[QK4_NL/2 + k]*src[QK4_NL/2 + k];

        sum_qx += wgt_lo*tab_lo*src[k] + wgt_hi*tab_hi*src[QK4_NL/2 + k];
        sum_qq += wgt_lo*tab_lo*tab_lo + wgt_hi*tab_hi*tab_hi;
    }

    // Prefer the refitted scale; fall back to the initial one when the denominator is not positive.
    blk->d = sum_qq > 0 ? sum_qx/sum_qq : scale;
}
|
||||||
|
|
||||||
|
|
||||||
template <cpy_kernel_t cpy_blck, int qk>
|
template <cpy_kernel_t cpy_blck, int qk>
|
||||||
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
@ -8490,6 +8607,39 @@ static void ggml_cpy_f32_q4_1_cuda(
|
|||||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Host-side launcher: enqueue cpy_f32_q specialised for cpy_blck_f32_q5_0 on `stream`.
// `ne` must be a multiple of QK5_0 (asserted). The grid has ne / QK5_0 thread blocks
// of a single thread each; all size/stride arguments are forwarded unchanged.
static void ggml_cpy_f32_q5_0_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK5_0 == 0);

    const int  n_quant_blocks = ne / QK5_0;
    const dim3 grid_dims(n_quant_blocks, 1, 1);
    const dim3 block_dims(1, 1, 1);

    cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<grid_dims, block_dims, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
|
||||||
|
|
||||||
|
// Host-side launcher: enqueue cpy_f32_q specialised for cpy_blck_f32_q5_1 on `stream`.
// `ne` must be a multiple of QK5_1 (asserted). The grid has ne / QK5_1 thread blocks
// of a single thread each; all size/stride arguments are forwarded unchanged.
static void ggml_cpy_f32_q5_1_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK5_1 == 0);

    const int  n_quant_blocks = ne / QK5_1;
    const dim3 grid_dims(n_quant_blocks, 1, 1);
    const dim3 block_dims(1, 1, 1);

    cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<grid_dims, block_dims, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
|
||||||
|
|
||||||
|
// Host-side launcher: enqueue cpy_f32_q specialised for cpy_blck_f32_iq4_nl on `stream`.
// `ne` must be a multiple of QK4_NL (asserted). The grid has ne / QK4_NL thread blocks
// of a single thread each; all size/stride arguments are forwarded unchanged.
static void ggml_cpy_f32_iq4_nl_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK4_NL == 0);

    const int  n_quant_blocks = ne / QK4_NL;
    const dim3 grid_dims(n_quant_blocks, 1, 1);
    const dim3 block_dims(1, 1, 1);

    cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<grid_dims, block_dims, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
|
||||||
|
|
||||||
static void ggml_cpy_f16_f16_cuda(
|
static void ggml_cpy_f16_f16_cuda(
|
||||||
const char * cx, char * cdst, const int ne,
|
const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
@ -10888,6 +11038,12 @@ static void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * s
|
|||||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||||
|
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||||
|
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||||
|
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||||
@ -11304,6 +11460,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
28
ggml-metal.m
28
ggml-metal.m
@ -173,8 +173,9 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
||||||
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
||||||
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
||||||
|
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_CONCAT,
|
GGML_METAL_KERNEL_TYPE_CONCAT,
|
||||||
@ -598,8 +599,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||||
@ -739,6 +741,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
case GGML_TYPE_IQ4_NL:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -2431,13 +2436,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
||||||
|
|
||||||
switch (dstt) {
|
switch (dstt) {
|
||||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
||||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
||||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
||||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
||||||
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
||||||
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
||||||
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
|
240
ggml-metal.metal
240
ggml-metal.metal
@ -2388,6 +2388,242 @@ kernel void kernel_cpy_f32_q4_1(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy a float tensor into a Q5_0-quantized destination.
// Each threadgroup handles the source row selected by (tgpig[0], tgpig[1], tgpig[2]);
// the threads of the group stride over that row in units of QK5_0 elements.
// ne0x/nb0x describe the source (sizes / byte strides), ne*/nb* the destination.
kernel void kernel_cpy_f32_q5_0(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // Flat element index of the start of this source row.
    const int64_t flat = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // Decompose the flat index into destination coordinates (i0 counted in quant blocks).
    const int64_t i3   = flat / (ne2*ne1*ne0);
    const int64_t rem3 = flat - i3*ne2*ne1*ne0;
    const int64_t i2   = rem3 / (ne1*ne0);
    const int64_t rem2 = rem3 - i2*ne1*ne0;
    const int64_t i1   = rem2 / ne0;
    const int64_t rem1 = rem2 - i1*ne0;
    const int64_t i0   = rem1/QK5_0;

    device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        device block_q5_0  * blk = dst_data + i00/QK5_0;

        // Locate the (signed) element with the largest magnitude in the block.
        float best_abs = 0.0f;
        float best_val = 0.0f;
        for (int k = 0; k < QK5_0; k++) {
            const float cur = src[k];
            if (best_abs < fabs(cur)) {
                best_abs = fabs(cur);
                best_val = cur;
            }
        }

        // Scale chosen so that the largest-magnitude element maps to -16 before the +16.5 offset below.
        const float scale     = best_val / -16;
        const float inv_scale = scale ? 1.0f/scale : 0.0f; // all-zero block: avoid dividing by zero

        blk->d = scale;

        // Bit 4 of every 5-bit quant is collected here; the low 4 bits go into qs.
        uint32_t high_bits = 0;
        for (int k = 0; k < QK5_0/2; ++k) {
            const float scaled_lo = src[0       + k]*inv_scale;
            const float scaled_hi = src[QK5_0/2 + k]*inv_scale;

            // Offset, truncate, and clamp to the largest 5-bit value.
            const uint8_t q_lo = MIN(31, (int8_t)(scaled_lo + 16.5f));
            const uint8_t q_hi = MIN(31, (int8_t)(scaled_hi + 16.5f));

            // Element k occupies the low nibble, element k + QK5_0/2 the high nibble.
            blk->qs[k] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);
            high_bits |= ((q_lo & 0x10u) >> 4) << (k + 0);
            high_bits |= ((q_hi & 0x10u) >> 4) << (k + QK5_0/2);
        }

        // Store the packed high bits one byte at a time.
        thread const uint8_t * high_bytes = (thread const uint8_t *)&high_bits;
        for (int b = 0; b < 4; ++b) {
            blk->qh[b] = high_bytes[b];
        }
    }
}
|
||||||
|
|
||||||
|
// Copy a float tensor into a Q5_1-quantized destination (scale + minimum encoding).
// Each threadgroup handles the source row selected by (tgpig[0], tgpig[1], tgpig[2]);
// the threads of the group stride over that row in units of QK5_1 elements.
// ne0x/nb0x describe the source (sizes / byte strides), ne*/nb* the destination.
kernel void kernel_cpy_f32_q5_1(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // Flat element index of the start of this source row.
    const int64_t flat = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // Decompose the flat index into destination coordinates (i0 counted in quant blocks).
    const int64_t i3   = flat / (ne2*ne1*ne0);
    const int64_t rem3 = flat - i3*ne2*ne1*ne0;
    const int64_t i2   = rem3 / (ne1*ne0);
    const int64_t rem2 = rem3 - i2*ne1*ne0;
    const int64_t i1   = rem2 / ne0;
    const int64_t rem1 = rem2 - i1*ne0;
    const int64_t i0   = rem1/QK5_1;

    device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        device block_q5_1  * blk = dst_data + i00/QK5_1;

        // Range of the block.
        float highest = src[0];
        float lowest  = src[0];
        for (int k = 1; k < QK5_1; k++) {
            const float cur = src[k];
            lowest  = cur < lowest  ? cur : lowest;
            highest = cur > highest ? cur : highest;
        }

        // 31 steps cover [lowest, highest].
        const float step     = (highest - lowest) / 31;
        const float inv_step = step ? 1.0f/step : 0.0f; // constant block: avoid dividing by zero

        blk->d = step;
        blk->m = lowest;

        // Bit 4 of every 5-bit quant is collected here; the low 4 bits go into qs.
        uint32_t high_bits = 0;
        for (int k = 0; k < QK5_1/2; ++k) {
            const float scaled_lo = (src[0       + k] - lowest)*inv_step;
            const float scaled_hi = (src[QK5_1/2 + k] - lowest)*inv_step;

            // Round to nearest by adding 0.5 and truncating.
            const uint8_t q_lo = (uint8_t)(scaled_lo + 0.5f);
            const uint8_t q_hi = (uint8_t)(scaled_hi + 0.5f);

            // Element k occupies the low nibble, element k + QK5_1/2 the high nibble.
            blk->qs[k] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);
            high_bits |= ((q_lo & 0x10u) >> 4) << (k + 0);
            high_bits |= ((q_hi & 0x10u) >> 4) << (k + QK5_1/2);
        }

        // Store the packed high bits one byte at a time.
        thread const uint8_t * high_bytes = (thread const uint8_t *)&high_bits;
        for (int b = 0; b < 4; ++b) {
            blk->qh[b] = high_bytes[b];
        }
    }
}
|
||||||
|
|
||||||
|
// Return the index of the entry of val[0..n-1] closest to x.
// The search below only works if val is sorted in ascending order (not checked here).
// Values at or beyond either end clamp to the first/last index; on an exact tie
// between two neighbours the upper index is returned.
// Note: despite the name, this variant takes a table of floats.
static inline int best_index_int8(int n, constant float * val, float x) {
    const int last = n-1;

    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[last]) {
        return last;
    }

    // Bisect until lo and hi are adjacent entries bracketing x.
    int lo = 0;
    int hi = last;
    while (hi-lo > 1) {
        const int mid = (lo+hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // Pick whichever neighbour is nearer to x.
    if (x - val[hi-1] < val[hi] - x) {
        return hi-1;
    }
    return hi;
}
|
||||||
|
|
||||||
|
// 16-entry table of float levels, in ascending order, that kernel_cpy_f32_iq4_nl
// searches with best_index_int8 to pick the 4-bit index stored for each element.
constexpr constant static float kvalues_iq4nl_f[16] = {
    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f,
       1.f,   13.f,  25.f,  38.f,  53.f,  69.f,  89.f, 113.f
};
|
||||||
|
|
||||||
|
// Copy a float tensor into an IQ4_NL-quantized destination.
// Each threadgroup handles the source row selected by (tgpig[0], tgpig[1], tgpig[2]);
// the threads of the group stride over that row in units of QK4_NL elements.
// ne0x/nb0x describe the source (sizes / byte strides), ne*/nb* the destination.
kernel void kernel_cpy_f32_iq4_nl(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // Flat element index of the start of this source row.
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // Decompose the flat index into destination coordinates (i0 counted in quant blocks).
    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;

    device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        float amax = 0.0f; // absolute max
        float max  = 0.0f; // signed value that attains amax

        // Scan the whole IQ4_NL block. The bound is QK4_NL (not QK4_0) so this loop
        // stays tied to the block size used by every other loop in this kernel.
        for (int j = 0; j < QK4_NL; j++) {
            const float v = src[j];
            if (amax < fabs(v)) {
                amax = fabs(v);
                max  = v;
            }
        }

        // Initial scale: map the largest-magnitude element onto the first table entry.
        const float d  = max / kvalues_iq4nl_f[0];
        const float id = d ? 1.0f/d : 0.0f; // all-zero block: avoid dividing by zero

        // Accumulators for the weighted least-squares refit of the scale (weights are x^2).
        float sumqx = 0, sumq2 = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
            const float x0 = src[0        + j]*id;
            const float x1 = src[QK4_NL/2 + j]*id;

            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);

            // Element j occupies the low nibble, element j + QK4_NL/2 the high nibble.
            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);

            const float v0 = kvalues_iq4nl_f[xi0];
            const float v1 = kvalues_iq4nl_f[xi1];
            const float w0 = src[0        + j]*src[0        + j];
            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
            sumq2 += w0*v0*v0 + w1*v1*v1;
        }

        // Prefer the refitted scale; fall back to the initial one when the denominator is not positive.
        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
    }
}
|
||||||
|
|
||||||
kernel void kernel_concat(
|
kernel void kernel_concat(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
@ -4220,10 +4456,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr constant static float kvalues_iq4nl_f[16] = {
|
|
||||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
|
||||||
};
|
|
||||||
|
|
||||||
void kernel_mul_mv_iq4_nl_f32_impl(
|
void kernel_mul_mv_iq4_nl_f32_impl(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
|
Loading…
Reference in New Issue
Block a user