ggml : use full range for Q4_0 and Q4_2 quantization (#729)

* Use full range for q4_0 quantization

By keeping the sign of the highest magnitude, we can make sure the
highest value maps to -8, which is currently unused.
This is a bit of a freebie since it is fully backwards compatible with
the current format.

* Update quantize_row_q4_0 for AVX/AVX2

* Update quantize_row_q4_0 for WASM

Untested

* Update quantize_row_q4_0 for Arm NEON

* Update quantize_row_q4_0 for PowerPC

Untested

* Use full range for q4_2 quantization
This commit is contained in:
unbounded 2023-04-25 19:20:46 +02:00 committed by GitHub
parent 54bb60e268
commit dd0eabc049
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

206
ggml.c
View File

@ -692,13 +692,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
float max = 0.0f;
for (int l = 0; l < QK4_0; l++) { for (int l = 0; l < QK4_0; l++) {
const float v = x[i*QK4_0 + l]; const float v = x[i*QK4_0 + l];
amax = MAX(amax, fabsf(v)); if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
} }
const float d = amax / ((1 << 3) - 1); const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
y[i].d = d; y[i].d = d;
@ -707,8 +711,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
const float v0 = x[i*QK4_0 + l + 0]*id; const float v0 = x[i*QK4_0 + l + 0]*id;
const float v1 = x[i*QK4_0 + l + 1]*id; const float v1 = x[i*QK4_0 + l + 1]*id;
const uint8_t vi0 = (int8_t)roundf(v0) + 8; const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
const uint8_t vi1 = (int8_t)roundf(v1) + 8; const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
assert(vi0 < 16); assert(vi0 < 16);
assert(vi1 < 16); assert(vi1 < 16);
@ -728,28 +732,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
#if defined(__POWER9_VECTOR__) #if defined(__POWER9_VECTOR__)
const vector float v85 = vec_splats(8.5f); const vector float v85 = vec_splats(8.5f);
const vector signed int v15 = vec_splats(15);
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float max = 0.0f;
float min = 0.0f;
vector float srcv [8]; vector float srcv [8];
vector float asrcv[8]; vector float maxv[8];
vector float amaxv[8]; vector float minv[8];
for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l); for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]); //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]); for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
//for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]); //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
amaxv[0] = vec_max(amaxv[0], amaxv[2]); maxv[0] = vec_max(maxv[0], maxv[2]);
amaxv[4] = vec_max(amaxv[4], amaxv[6]); maxv[4] = vec_max(maxv[4], maxv[6]);
//for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]); //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
amaxv[0] = vec_max(amaxv[0], amaxv[4]); maxv[0] = vec_max(maxv[0], maxv[4]);
amax = MAX( for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)), //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3))); minv[0] = vec_min(minv[0], minv[2]);
minv[4] = vec_min(minv[4], minv[6]);
//for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
minv[0] = vec_min(minv[0], minv[4]);
const float d = amax / ((1 << 3) - 1);
max = MAX(
MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
min = MIN(
MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
const float magnitude = max >= fabsf(min) ? max : min;
const float d = magnitude / -8;
const float id = d ? 1.0/d : 0.0; const float id = d ? 1.0/d : 0.0;
y[i].d = d; y[i].d = d;
@ -759,27 +777,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
for (int l = 0; l < 8; l++) { for (int l = 0; l < 8; l++) {
const vector float vf = vec_madd(srcv[l], vid, v85); const vector float vf = vec_madd(srcv[l], vid, v85);
const vector signed int vi = vec_signed(vf); const vector signed int vi = vec_signed(vf);
const vector signed int vc = vec_min(vi, v15);
pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4); pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4); pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
} }
} }
#elif __ARM_NEON #elif __ARM_NEON
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float32x4_t srcv [8]; float32x4_t srcv [8];
float32x4_t asrcv[8]; float32x4_t maxv[8];
float32x4_t amaxv[8]; float32x4_t minv[8];
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
const float amax = vmaxvq_f32(amaxv[0]); for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
const float d = amax / ((1 << 3) - 1); const float max = vmaxvq_f32(maxv[0]);
const float min = vminvq_f32(minv[0]);
const float magnitude = max >= fabsf(min) ? max : min;
const float d = magnitude / -8;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
y[i].d = d; y[i].d = d;
@ -788,9 +812,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const float32x4_t v = vmulq_n_f32(srcv[l], id); const float32x4_t v = vmulq_n_f32(srcv[l], id);
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
const int32x4_t vi = vcvtq_s32_f32(vf); const int32x4_t vi = vcvtq_s32_f32(vf);
const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
} }
} }
#elif defined(__AVX2__) #elif defined(__AVX2__)
@ -802,22 +827,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
__m256 v3 = _mm256_loadu_ps( x + 24 ); __m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32; x += 32;
// Compute max(abs(e)) for the block // Compute max for the block
const __m256 signBit = _mm256_set1_ps( -0.0f ); __m256 max = _mm256_max_ps( v0, v1 );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); __m256 maxTmp = _mm256_max_ps( v2, v3 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); max = _mm256_max_ps( max, maxTmp );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 ); const float maxScalar = _mm_cvtss_f32( max4 );
// Compute min for the block
__m256 min = _mm256_min_ps( v0, v1 );
__m256 minTmp = _mm256_min_ps( v2, v3 );
min = _mm256_min_ps( min, minTmp );
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
const float minScalar = _mm_cvtss_f32( min4 );
// Quantize these floats // Quantize these floats
const float d = maxScalar / 7.0f; const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
const float d = magnitude / -8.0f;
y[i].d = d; y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
const __m256 mul = _mm256_set1_ps( id ); const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier // Apply the multiplier
@ -850,9 +884,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm ); i0 = _mm256_permutevar8x32_epi32( i0, perm );
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ] // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
const __m256i off = _mm256_set1_epi8( 8 ); const __m256i off = _mm256_set1_epi8( 8 );
i0 = _mm256_add_epi8( i0, off ); i0 = _mm256_add_epi8( i0, off );
const __m256i maxNibble = _mm256_set1_epi8( 15 );
i0 = _mm256_min_epi8( i0, maxNibble );
// Compress the vector into 4 bit/value, and store // Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( i0 ); __m128i res = packNibbles( i0 );
@ -867,22 +903,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
__m256 v3 = _mm256_loadu_ps( x + 24 ); __m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32; x += 32;
// Compute max(abs(e)) for the block // Compute max for the block
const __m256 signBit = _mm256_set1_ps( -0.0f ); __m256 max = _mm256_max_ps( v0, v1 );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); __m256 maxTmp = _mm256_max_ps( v2, v3 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); max = _mm256_max_ps( max, maxTmp );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 ); const float maxScalar = _mm_cvtss_f32( max4 );
// Compute min for the block
__m256 min = _mm256_min_ps( v0, v1 );
__m256 minTmp = _mm256_min_ps( v2, v3 );
min = _mm256_min_ps( min, minTmp );
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
const float minScalar = _mm_cvtss_f32( min4 );
// Quantize these floats // Quantize these floats
const float d = maxScalar / 7.0f; const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
const float d = magnitude / -8.0f;
y[i].d = d; y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
const __m256 mul = _mm256_set1_ps( id ); const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier // Apply the multiplier
@ -923,10 +968,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
ni0 = _mm_packs_epi16( ni0, ni2 ); ni0 = _mm_packs_epi16( ni0, ni2 );
ni4 = _mm_packs_epi16( ni4, ni6 ); ni4 = _mm_packs_epi16( ni4, ni6 );
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ] // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
const __m128i off = _mm_set1_epi8( 8); const __m128i off = _mm_set1_epi8( 8 );
ni0 = _mm_add_epi8( ni0, off ); ni0 = _mm_add_epi8( ni0, off );
ni4 = _mm_add_epi8( ni4, off ); ni4 = _mm_add_epi8( ni4, off );
const __m128i maxNibble = _mm_set1_epi8( 15 );
ni0 = _mm_min_epi8( ni0, maxNibble );
ni4 = _mm_min_epi8( ni4, maxNibble );
// Compress the vector into 4 bit/value, and store // Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( ni0, ni4 ); __m128i res = packNibbles( ni0, ni4 );
@ -934,24 +982,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
} }
#elif defined(__wasm_simd128__) #elif defined(__wasm_simd128__)
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float max = 0.0f;
float min = 0.0f;
v128_t srcv [8]; v128_t srcv [8];
v128_t asrcv[8]; v128_t maxv[8];
v128_t amaxv[8]; v128_t minv[8];
for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l); for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]); for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]); for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]); for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
amax = MAX( for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)), for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3))); for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
const float d = amax / ((1 << 3) - 1); max = MAX(
MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
min = MIN(
MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
const float magnitude = max >= fabsf(min) ? max : min;
const float d = magnitude / -8;
const float id = d ? 1.0/d : 0.0; const float id = d ? 1.0/d : 0.0;
y[i].d = d; y[i].d = d;
@ -960,9 +1016,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f)); const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf); const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4); y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
} }
} }
#else #else
@ -1143,13 +1200,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
float max = 0.0f;
for (int l = 0; l < QK4_2; l++) { for (int l = 0; l < QK4_2; l++) {
const float v = x[i*QK4_2 + l]; const float v = x[i*QK4_2 + l];
amax = MAX(amax, fabsf(v)); if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
} }
const float d = amax / ((1 << 3) - 1); const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
@ -1159,8 +1220,8 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
const float v0 = x[i*QK4_2 + l + 0]*id; const float v0 = x[i*QK4_2 + l + 0]*id;
const float v1 = x[i*QK4_2 + l + 1]*id; const float v1 = x[i*QK4_2 + l + 1]*id;
const uint8_t vi0 = (uint8_t)(v0 + 8.5f); const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
const uint8_t vi1 = (uint8_t)(v1 + 8.5f); const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
assert(vi0 < 16); assert(vi0 < 16);
assert(vi1 < 16); assert(vi1 < 16);
@ -1254,9 +1315,7 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
block_q4_2 * restrict y = vy; block_q4_2 * restrict y = vy;
//quantize_row_q4_2_reference(x, y, k); quantize_row_q4_2_reference(x, y, k);
// This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
quantize_row_q4_2_rmse(x, y, k);
} }
static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) { static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
@ -1807,7 +1866,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_2] = { [GGML_TYPE_Q4_2] = {
.dequantize_row_q = dequantize_row_q4_2, .dequantize_row_q = dequantize_row_q4_2,
.quantize_row_q = quantize_row_q4_2, .quantize_row_q = quantize_row_q4_2,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
.quantize_row_q_dot = quantize_row_q8_0, .quantize_row_q_dot = quantize_row_q8_0,
.vec_dot_q = ggml_vec_dot_q4_2_q8_0, .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
}, },
@ -12144,8 +12203,7 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
for (int j = 0; j < n; j += k) { for (int j = 0; j < n; j += k) {
block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2; block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
//quantize_row_q4_2_reference(src + j, y, k); quantize_row_q4_2_reference(src + j, y, k);
quantize_row_q4_2_rmse(src + j, y, k);
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
for (int l = 0; l < QK4_2; l += 2) { for (int l = 0; l < QK4_2; l += 2) {