Q2_K: fixed bug in imatrix quantization for QK_K = 64

This commit is contained in:
Iwan Kawrakow 2024-02-28 08:15:52 +02:00
parent 2540a290ed
commit 47d52b2b24

View File

@ -1877,7 +1877,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
float mins[QK_K/16]; float mins[QK_K/16];
float scales[QK_K/16]; float scales[QK_K/16];
float sw[QK_K/16]; float sw[QK_K/16];
float weight[QK_K/16]; float weight[16];
uint8_t Ls[QK_K/16], Lm[QK_K/16]; uint8_t Ls[QK_K/16], Lm[QK_K/16];
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
@ -1887,13 +1887,42 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
float sigma2 = sumx2/QK_K; float sigma2 = sumx2/QK_K;
for (int j = 0; j < QK_K/16; ++j) { for (int j = 0; j < QK_K/16; ++j) {
const float * restrict qw = quant_weights + QK_K * i + 16*j; const float * restrict qw = quant_weights + QK_K * i + 16*j;
for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l]; for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
} }
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); float dm, mm;
float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); #if QK_K == 64
float max_scale = 0, max_min = 0;
for (int j = 0; j < QK_K/16; ++j) {
max_scale = MAX(max_scale, scales[j]);
max_min = MAX(max_min, mins[j]);
}
dm = max_scale/15;
mm = max_min/15;
if (max_scale) {
float id = 1/dm;
for (int j = 0; j < QK_K/16; ++j) {
int l = nearest_int(id*scales[j]);
Ls[j] = MAX(0, MIN(15, l));
}
} else {
memset(Ls, 0, QK_K/16);
}
if (max_min) {
float id = 1/mm;
for (int j = 0; j < QK_K/16; ++j) {
int l = nearest_int(id*mins[j]);
Lm[j] = MAX(0, MIN(15, l));
}
} else {
memset(Lm, 0, QK_K/16);
}
#else
dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
#endif
y[i].d = GGML_FP32_TO_FP16(dm); y[i].d = GGML_FP32_TO_FP16(dm);
y[i].dmin = GGML_FP32_TO_FP16(mm); y[i].dmin = GGML_FP32_TO_FP16(mm);
dm = GGML_FP16_TO_FP32(y[i].d); dm = GGML_FP16_TO_FP32(y[i].d);
@ -6310,7 +6339,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
float sumf = 0; float sumf = 0;
int isum[4]; int isum[QK_K/16];
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
@ -6326,14 +6355,14 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
isum[0] = isum[1] = isum[2] = isum[3] = 0; memset(isum, 0, (QK_K/16)*sizeof(int));
for (int l = 0; l < 16; ++l) { for (int l = 0; l < 16; ++l) {
isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
} }
for (int l = 0; l < 4; ++l) { for (int l = 0; l < QK_K/16; ++l) {
isum[l] *= (sc[l] & 0xF); isum[l] *= (sc[l] & 0xF);
} }
sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;