ggml : introduce structs for the q4 data blocks (#356)

* Introduce structs for the q4 data blocks

* ggml : rename quant struct variables + fix ARM_NEON

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Stephan Walter 2023-03-28 15:56:03 +00:00 committed by GitHub
parent e0670260fb
commit c1f885067c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 150 additions and 235 deletions

View File

@ -4,8 +4,6 @@
#include <cstdio> #include <cstdio>
#include <string> #include <string>
const int QK = 32;
// usage: // usage:
// ./llama-quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type // ./llama-quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type
// //
@ -39,7 +37,7 @@ int main(int argc, char ** argv) {
{ {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype, QK)) { if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype)) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
return 1; return 1;
} }

359
ggml.c
View File

@ -448,17 +448,27 @@ static inline __m128i packNibbles( __m256i bytes )
// method 5 // method 5
// blocks of QK elements // blocks of QK elements
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
typedef struct {
float d; // delta
uint8_t qs[QK / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
// method 4
// blocks of QK elements
// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
typedef struct {
float d;
float m;
uint8_t qs[QK / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
// reference implementation for deterministic creation of model files // reference implementation for deterministic creation of model files
static void quantize_row_q4_0_reference(const float * restrict x, void * restrict y, int k) { static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
assert(k % QK == 0); assert(k % QK == 0);
const int nb = k / QK; const int nb = k / QK;
const size_t bs = sizeof(float) + QK/2;
uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));
uint8_t pp[QK/2]; uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
@ -472,8 +482,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric
const float d = amax / ((1 << 3) - 1); const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
*(float *)pd = d; y[i].d = d;
pd += bs;
for (int l = 0; l < QK; l += 2) { for (int l = 0; l < QK; l += 2) {
const float v0 = x[i*QK + l + 0]*id; const float v0 = x[i*QK + l + 0]*id;
@ -488,23 +497,15 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric
pp[l/2] = vi0 | (vi1 << 4); pp[l/2] = vi0 | (vi1 << 4);
} }
memcpy(pb, pp, sizeof(pp)); memcpy(y[i].qs, pp, sizeof(pp));
pb += bs;
} }
} }
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
assert(k % QK == 0); assert(k % QK == 0);
#if defined(__ARM_NEON) || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
const int nb = k / QK; const int nb = k / QK;
const size_t bs = sizeof(float) + QK/2;
uint8_t * restrict pd = ((uint8_t *)y + 0*bs); block_q4_0 * restrict y = vy;
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));
uint8_t pp[QK/2];
#endif
#if defined(__POWER9_VECTOR__) #if defined(__POWER9_VECTOR__)
const vector float v85 = vec_splats(8.5f); const vector float v85 = vec_splats(8.5f);
@ -532,10 +533,10 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ((1 << 3) - 1); const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0/d : 0.0; const float id = d ? 1.0/d : 0.0;
*(float *)pd = d; y[i].d = d;
pd += bs;
const vector float vid = vec_splats(id); const vector float vid = vec_splats(id);
uint8_t * restrict pb = y[i].qs;
for (int l = 0; l < 8; l++) { for (int l = 0; l < 8; l++) {
const vector float vf = vec_madd(srcv[l], vid, v85); const vector float vf = vec_madd(srcv[l], vid, v85);
const vector signed int vi = vec_signed(vf); const vector signed int vi = vec_signed(vf);
@ -543,11 +544,9 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4); pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4); pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
} }
//memcpy(pb, pp, sizeof(pp));
pb += bs;
} }
#elif __ARM_NEON #elif __ARM_NEON
uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
@ -569,8 +568,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ((1 << 3) - 1); const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0/d : 0.0; const float id = d ? 1.0/d : 0.0;
*(float *)pd = d; y[i].d = d;
pd += bs;
for (int l = 0; l < 8; l++) { for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id); const float32x4_t v = vmulq_n_f32(srcv[l], id);
@ -581,8 +579,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
} }
memcpy(pb, pp, sizeof(pp)); memcpy(y[i].qs, pp, sizeof(pp));
pb += bs;
} }
#elif defined(__AVX2__) #elif defined(__AVX2__)
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
@ -607,8 +604,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
// Quantize these floats // Quantize these floats
const float d = maxScalar / 7.0f; const float d = maxScalar / 7.0f;
*(float *)pd = d; y[i].d = d;
pd += bs;
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id ); const __m256 mul = _mm256_set1_ps( id );
@ -648,10 +644,10 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
// Compress the vector into 4 bit/value, and store // Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( i0 ); __m128i res = packNibbles( i0 );
_mm_storeu_si128( ( __m128i* )pb, res ); _mm_storeu_si128( ( __m128i* )y[i].qs, res );
pb += bs;
} }
#elif defined(__wasm_simd128__) #elif defined(__wasm_simd128__)
uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
@ -673,8 +669,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ((1 << 3) - 1); const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0/d : 0.0; const float id = d ? 1.0/d : 0.0;
*(float *)pd = d; y[i].d = d;
pd += bs;
for (int l = 0; l < 8; l++) { for (int l = 0; l < 8; l++) {
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
@ -685,8 +680,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
} }
memcpy(pb, pp, sizeof(pp)); memcpy(y[i].qs, pp, sizeof(pp));
pb += bs;
} }
#else #else
// scalar // scalar
@ -694,18 +688,11 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
#endif #endif
} }
// method 4 static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
// blocks of QK elements
// represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
assert(k % QK == 0); assert(k % QK == 0);
const int nb = k / QK; const int nb = k / QK;
const size_t bs = 2*sizeof(float) + QK/2;
uint8_t * restrict pd = ((uint8_t *)y + 0*bs); block_q4_1 * restrict y = vy;
uint8_t * restrict pm = ((uint8_t *)y + 0*bs + sizeof(float));
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));
uint8_t pp[QK/2]; uint8_t pp[QK/2];
@ -722,10 +709,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
const float d = (max - min) / ((1 << 4) - 1); const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
*(float *)pm = min; y[i].d = d;
*(float *)pd = d; y[i].m = min;
pm += bs;
pd += bs;
for (int l = 0; l < QK; l += 2) { for (int l = 0; l < QK; l += 2) {
const float v0 = (x[i*QK + l + 0] - min)*id; const float v0 = (x[i*QK + l + 0] - min)*id;
@ -740,27 +725,22 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
pp[l/2] = vi0 | (vi1 << 4); pp[l/2] = vi0 | (vi1 << 4);
} }
memcpy(pb, pp, sizeof(pp)); memcpy(y[i].qs, pp, sizeof(pp));
pb += bs;
} }
} }
// TODO: vectorize static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
assert(k % QK == 0); assert(k % QK == 0);
const int nb = k / QK; const int nb = k / QK;
const size_t bs = sizeof(float) + QK/2;
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs); const block_q4_0 * restrict x = vx;
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
#if defined(__AVX2__) #if defined(__AVX2__)
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
// scale factor // scale factor
const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs)); const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
const uint8_t * restrict pp = pb + i*bs; const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK; l += 32) { for (int l = 0; l < QK; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers // Load 32x4-bit integers into 32x8-bit integers
@ -790,17 +770,15 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
} }
#elif defined(__ARM_NEON) #elif defined(__ARM_NEON)
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d = *(const float *) (pd + i*bs); const float32x4_t vd = vdupq_n_f32(x[i].d);
const uint8_t * restrict pp = pb + i*bs; const uint8_t * restrict pp = x[i].qs;
const float32x4_t vd = vdupq_n_f32(d);
for (int l = 0; l < QK; l += 16) { for (int l = 0; l < QK; l += 16) {
// Load 16x4-bit integers into 8x8-bit integers // Load 16x4-bit integers into 8x8-bit integers
const uint8x8_t v8 = vld1_u8(pp + l/2); const uint8x8_t v8 = vld1_u8(pp + l/2);
// Expand 4-bit nibbles to 8-bit bytes // Expand 4-bit qs to 8-bit bytes
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
const uint8x8_t v1 = vshr_n_u8(v8, 4); const uint8x8_t v1 = vshr_n_u8(v8, 4);
@ -844,9 +822,9 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
#else #else
// scalar // scalar
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d = *(const float *) (pd + i*bs); const float d = x[i].d;
const uint8_t * restrict pp = pb + i*bs; const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK; l += 2) { for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2]; const uint8_t vi = pp[l/2];
@ -869,22 +847,18 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
#endif #endif
} }
void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
assert(k % QK == 0); assert(k % QK == 0);
const int nb = k / QK; const int nb = k / QK;
const size_t bs = 2*sizeof(float) + QK/2;
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs); const block_q4_1 * restrict x = vx;
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
#if defined(__AVX2__) #if defined(__AVX2__)
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs)); const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
const __m256 d_m = _mm256_broadcast_ss((const float *) (pm + i*bs)); const __m256 d_m = _mm256_broadcast_ss(&x[i].m);
const uint8_t * restrict pp = pb + i*bs; const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK; l += 32) { for (int l = 0; l < QK; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers // Load 32x4-bit integers into 32x8-bit integers
@ -911,10 +885,10 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
} }
#else #else
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d = *(const float *) (pd + i*bs); const float d = x[i].d;
const float m = *(const float *) (pm + i*bs); const float m = x[i].m;
const uint8_t * restrict pp = pb + i*bs; const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK; l += 2) { for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2]; const uint8_t vi = pp[l/2];
@ -1502,25 +1476,15 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
#if __AVX512F__ && QK == 32 #if __AVX512F__ && QK == 32
static inline __m512 dot_q4_0_oneblock_avx512( static inline __m512 dot_q4_0_oneblock_avx512(
__m512 acc, __m512 acc,
const uint8_t * pd0, const block_q4_0 * restrict x,
const uint8_t * pd1, const block_q4_0 * restrict y,
const uint8_t * pb0,
const uint8_t * pb1,
size_t bs,
int i int i
) { ) {
const float * d0_0 = (const float *) (pd0 + i*bs);
const float * d1_0 = (const float *) (pd1 + i*bs);
const uint8_t * restrict p0 = pb0 + (i+0)*bs;
const uint8_t * restrict p1 = pb1 + (i+0)*bs;
// Compute combined scale for the block // Compute combined scale for the block
float scaleScalar = d0_0[0] * d1_0[0]; __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
__m512 scale = _mm512_set1_ps( scaleScalar );
__m256i bx = bytesFromNibbles( p0 ); __m256i bx = bytesFromNibbles( x[i].qs );
__m256i by = bytesFromNibbles( p1 ); __m256i by = bytesFromNibbles( y[i].qs );
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 ); const __m256i off = _mm256_set1_epi8( 8 );
@ -1536,7 +1500,7 @@ static inline __m512 dot_q4_0_oneblock_avx512(
// Convert int32_t to float // Convert int32_t to float
__m512 p = _mm512_cvtepi32_ps( i64 ); __m512 p = _mm512_cvtepi32_ps( i64 );
// Apply the scale, and accumulate // Apply the scale, and accumulate
return _mm512_fmadd_ps( scale, p, acc ); return _mm512_fmadd_ps( d, p, acc );
} }
#endif #endif
@ -1576,19 +1540,14 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
*s = sumf; *s = sumf;
} }
inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) { inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK; const int nb = n / QK;
assert(n % QK == 0); assert(n % QK == 0);
assert(nb % 2 == 0); assert(nb % 2 == 0);
const size_t bs = sizeof(float) + QK/2; const block_q4_0 * restrict x = vx;
const block_q4_0 * restrict y = vy;
const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);
const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + sizeof(float));
const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + sizeof(float));
float sumf = 0.0; float sumf = 0.0;
@ -1597,23 +1556,18 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sum1 = 0.0f; float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) { for (int i = 0; i < nb; i += 2) {
const float d0_0 = *(const float *) (pd0 + i*bs); const block_q4_0 * restrict x0 = &x[i + 0];
const float d1_0 = *(const float *) (pd1 + i*bs); const block_q4_0 * restrict y0 = &y[i + 0];
const float d0_1 = *(const float *) (pd0 + (i + 1)*bs); const block_q4_0 * restrict x1 = &x[i + 1];
const float d1_1 = *(const float *) (pd1 + (i + 1)*bs); const block_q4_0 * restrict y1 = &y[i + 1];
//printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
const uint8x16_t m4b = vdupq_n_u8(0xf); const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8); const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(p0); const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v1_0 = vld1q_u8(p1); const uint8x16_t v1_0 = vld1q_u8(y0->qs);
const uint8x16_t v0_1 = vld1q_u8(p0 + bs); const uint8x16_t v0_1 = vld1q_u8(x1->qs);
const uint8x16_t v1_1 = vld1q_u8(p1 + bs); const uint8x16_t v1_1 = vld1q_u8(y1->qs);
// 4-bit -> 8-bit // 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
@ -1651,11 +1605,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
// scalar // scalar
#if defined(__ARM_FEATURE_QRDMX) #if defined(__ARM_FEATURE_QRDMX)
sum0 += d0_0*d1_0*vaddvq_s32(p_0); sum0 += x0->d * y0->d * vaddvq_s32(p_0);
sum1 += d0_1*d1_1*vaddvq_s32(p_1); sum1 += x1->d * y1->d * vaddvq_s32(p_1);
#else #else
sum0 += d0_0*d1_0*(vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3)); sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
sum1 += d0_1*d1_1*(vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3)); sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
#endif #endif
#else #else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
@ -1681,11 +1635,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
// scalar // scalar
#if defined(__ARM_FEATURE_QRDMX) #if defined(__ARM_FEATURE_QRDMX)
sum0 += d0_0*d1_0*vaddvq_s16(p_0); sum0 += x0->d * y0->d * vaddvq_s16(p_0);
sum1 += d0_1*d1_1*vaddvq_s16(p_1); sum1 += x1->d * y1->d * vaddvq_s16(p_1);
#else #else
sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7)); sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7)); sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
#endif #endif
#endif #endif
} }
@ -1703,19 +1657,19 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) { for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
int i = superblock_ix * superblock_size; int i = superblock_ix * superblock_size;
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 ); acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 ); acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 ); acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 ); acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 ); acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 ); acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 ); acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 ); acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
} }
// Remainders // Remainders
for (int i = superblock_count * superblock_size; i < nb; ++i) { for (int i = superblock_count * superblock_size; i < nb; ++i) {
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i ); acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
} }
// Horizontal sum of all lanes of the accumulator // Horizontal sum of all lanes of the accumulator
@ -1726,18 +1680,12 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
// Main loop // Main loop
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
const float * d0_0 = (const float *) (pd0 + i*bs);
const float * d1_0 = (const float *) (pd1 + i*bs);
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
// Compute combined scale for the block // Compute combined scale for the block
const __m256 scale = _mm256_mul_ps( _mm256_broadcast_ss( d0_0 ), _mm256_broadcast_ss( d1_0 ) ); const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
__m256i bx = bytesFromNibbles( p0 ); __m256i bx = bytesFromNibbles( x[i].qs );
__m256i by = bytesFromNibbles( p1 ); __m256i by = bytesFromNibbles( y[i].qs );
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 ); const __m256i off = _mm256_set1_epi8( 8 );
@ -1759,7 +1707,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
// Convert int32_t to float // Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( i32 ); __m256 p = _mm256_cvtepi32_ps( i32 );
// Apply the scale, and accumulate // Apply the scale, and accumulate
acc = _mm256_fmadd_ps( scale, p, acc ); acc = _mm256_fmadd_ps( d, p, acc );
} }
// Return horizontal sum of the acc vector // Return horizontal sum of the acc vector
@ -1775,21 +1723,18 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sum1 = 0.0f; float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) { for (int i = 0; i < nb; i += 2) {
const float d0_0 = *(const float *) (pd0 + i*bs); const block_q4_0 * restrict x0 = &px[i + 0];
const float d1_0 = *(const float *) (pd1 + i*bs); const block_q4_0 * restrict y0 = &py[i + 0];
const float d0_1 = *(const float *) (pd0 + (i + 1)*bs); const block_q4_0 * restrict x1 = &px[i + 1];
const float d1_1 = *(const float *) (pd1 + (i + 1)*bs); const block_q4_0 * restrict y1 = &py[i + 1];
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
const v128_t m4b = wasm_u8x16_splat(0xf); const v128_t m4b = wasm_u8x16_splat(0xf);
const v128_t s8b = wasm_i8x16_splat(0x8); const v128_t s8b = wasm_i8x16_splat(0x8);
const v128_t v0_0 = wasm_v128_load(p0); const v128_t v0_0 = wasm_v128_load(x0.qs);
const v128_t v0_1 = wasm_v128_load(p0 + bs); const v128_t v0_1 = wasm_v128_load(y0.qs);
const v128_t v1_0 = wasm_v128_load(p1); const v128_t v1_0 = wasm_v128_load(x1.qs);
const v128_t v1_1 = wasm_v128_load(p1 + bs); const v128_t v1_1 = wasm_v128_load(y1.qs);
// 4-bit -> 8-bit // 4-bit -> 8-bit
const v128_t v0_0l = wasm_v128_and(v0_0, m4b); const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@ -1839,12 +1784,12 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0); const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1); const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
sum0 += d0_0*d1_0*( sum0 += x0->d * y0->d * (
wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) + wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) + wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) + wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7)); wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
sum1 += d0_1*d1_1*( sum1 += x1->d * y1->d * (
wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) + wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) + wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) + wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
@ -1855,11 +1800,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
#else #else
// scalar // scalar
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d0 = *(const float *) (pd0 + i*bs); const float d0 = x[i].d;
const float d1 = *(const float *) (pd1 + i*bs); const float d1 = y[i].d;
const uint8_t * restrict p0 = pb0 + i*bs; const uint8_t * restrict p0 = x[i].qs;
const uint8_t * restrict p1 = pb1 + i*bs; const uint8_t * restrict p1 = y[i].qs;
for (int j = 0; j < QK/2; j++) { for (int j = 0; j < QK/2; j++) {
const uint8_t v0 = p0[j]; const uint8_t v0 = p0[j];
@ -1879,19 +1824,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
*s = sumf; *s = sumf;
} }
inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK; const int nb = n / QK;
const size_t bs = 2*sizeof(float) + QK/2; const block_q4_1 * restrict x = vx;
const block_q4_1 * restrict y = vy;
const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);
const uint8_t * restrict pm0 = ((const uint8_t *)x + 0*bs + sizeof(float));
const uint8_t * restrict pm1 = ((const uint8_t *)y + 0*bs + sizeof(float));
const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + 2*sizeof(float));
float sumf = 0.0; float sumf = 0.0;
@ -1903,21 +1840,17 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
// Main loop // Main loop
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
const float * m0 = (const float *) (pm0 + i*bs); const float * d0 = &x[i].d;
const float * m1 = (const float *) (pm1 + i*bs); const float * d1 = &y[i].d;
const float * d0 = (const float *) (pd0 + i*bs); const float * m0 = &x[i].m;
const float * d1 = (const float *) (pd1 + i*bs); const float * m1 = &y[i].m;
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
const __m256 d0v = _mm256_broadcast_ss( d0 ); const __m256 d0v = _mm256_broadcast_ss( d0 );
const __m256 d1v = _mm256_broadcast_ss( d1 ); const __m256 d1v = _mm256_broadcast_ss( d1 );
const __m256 m0v = _mm256_broadcast_ss( m0 ); const __m256 m0v = _mm256_broadcast_ss( m0 );
const __m256 m1v = _mm256_broadcast_ss( m1 ); const __m256 m1v = _mm256_broadcast_ss( m1 );
// Compute combined scale for the block // Compute combined scale for the block
const __m256 scale_01 = _mm256_mul_ps( d0v, d1v ); const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
@ -1927,8 +1860,8 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 ); const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
__m256i bx = bytesFromNibbles( p0 ); __m256i bx = bytesFromNibbles( x[i].qs );
__m256i by = bytesFromNibbles( p1 ); __m256i by = bytesFromNibbles( y[i].qs );
// Now we have a vector with bytes in [ 0 .. 15 ] interval. // Now we have a vector with bytes in [ 0 .. 15 ] interval.
@ -1973,14 +1906,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
#else #else
// scalar // scalar
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float m0 = *(const float *) (pm0 + i*bs); const float d0 = x[i].d;
const float m1 = *(const float *) (pm1 + i*bs); const float d1 = y[i].d;
const float d0 = *(const float *) (pd0 + i*bs); const float m0 = x[i].m;
const float d1 = *(const float *) (pd1 + i*bs); const float m1 = y[i].m;
const uint8_t * restrict p0 = pb0 + i*bs; const uint8_t * restrict p0 = x[i].qs;
const uint8_t * restrict p1 = pb1 + i*bs; const uint8_t * restrict p1 = y[i].qs;
for (int j = 0; j < QK/2; j++) { for (int j = 0; j < QK/2; j++) {
const uint8_t v0 = p0[j]; const uint8_t v0 = p0[j];
@ -2251,8 +2184,8 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5"); static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
sizeof(float ) + QK/2, sizeof(block_q4_0),
sizeof(float )*2 + QK/2, sizeof(block_q4_1),
sizeof(int8_t ), sizeof(int8_t ),
sizeof(int16_t), sizeof(int16_t),
sizeof(int32_t), sizeof(int32_t),
@ -10369,64 +10302,50 @@ enum ggml_opt_result ggml_opt(
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist) { size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
const int nb = k / qk; assert(k % QK == 0);
const size_t bs = (sizeof(float) + sizeof(uint8_t)*qk/2); const int nb = k / QK;
const size_t row_size = nb*bs;
assert(k % qk == 0);
char * pdst = (char *) dst;
for (int j = 0; j < n; j += k) { for (int j = 0; j < n; j += k) {
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs); block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
quantize_row_q4_0_reference(src + j, pd, k); quantize_row_q4_0_reference(src + j, y, k);
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
for (int l = 0; l < qk; l += 2) { for (int l = 0; l < QK; l += 2) {
const uint8_t vi0 = pb[l/2] & 0xF; const uint8_t vi0 = y[i].qs[l/2] & 0xF;
const uint8_t vi1 = pb[l/2] >> 4; const uint8_t vi1 = y[i].qs[l/2] >> 4;
hist[vi0]++; hist[vi0]++;
hist[vi1]++; hist[vi1]++;
} }
pb += bs;
} }
} }
return (n/k)*row_size; return (n/QK*sizeof(block_q4_0));
} }
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist) { size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
const int nb = k / qk; assert(k % QK == 0);
const size_t bs = (2*sizeof(float) + sizeof(uint8_t)*qk/2); const int nb = k / QK;
const size_t row_size = nb*bs;
assert(k % qk == 0);
char * pdst = (char *) dst;
for (int j = 0; j < n; j += k) { for (int j = 0; j < n; j += k) {
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs); block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
quantize_row_q4_1(src + j, pd, k); quantize_row_q4_1(src + j, y, k);
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
for (int l = 0; l < qk; l += 2) { for (int l = 0; l < QK; l += 2) {
const uint8_t vi0 = pb[l/2] & 0xF; const uint8_t vi0 = y[i].qs[l/2] & 0xF;
const uint8_t vi1 = pb[l/2] >> 4; const uint8_t vi1 = y[i].qs[l/2] >> 4;
hist[vi0]++; hist[vi0]++;
hist[vi1]++; hist[vi1]++;
} }
pb += bs;
} }
} }
return (n/k)*row_size; return (n/QK*sizeof(block_q4_1));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

4
ggml.h
View File

@ -748,8 +748,8 @@ enum ggml_opt_result ggml_opt(
// quantization // quantization
// //
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist); size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist); size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
// //
// system info // system info

View File

@ -1345,7 +1345,7 @@ static llama_vocab::id llama_sample_top_p_top_k(
// //
// TODO: reuse code from the llama_model_load() somehow // TODO: reuse code from the llama_model_load() somehow
bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype, int qk) { static bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
ggml_type type = GGML_TYPE_Q4_1; ggml_type type = GGML_TYPE_Q4_1;
switch (itype) { switch (itype) {
@ -1568,11 +1568,11 @@ bool llama_model_quantize_internal(const std::string & fname_inp, const std::str
switch (type) { switch (type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
{ {
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data()); cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break; } break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
{ {
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data()); cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break; } break;
default: default:
{ {
@ -1711,9 +1711,8 @@ void llama_free(struct llama_context * ctx) {
int llama_model_quantize( int llama_model_quantize(
const char * fname_inp, const char * fname_inp,
const char * fname_out, const char * fname_out,
int itype, int itype) {
int qk) { if (!llama_model_quantize_internal(fname_inp, fname_out, itype)) {
if (!llama_model_quantize_internal(fname_inp, fname_out, itype, qk)) {
fprintf(stderr, "%s: failed to quantize\n", __func__); fprintf(stderr, "%s: failed to quantize\n", __func__);
return 1; return 1;
} }

View File

@ -81,8 +81,7 @@ extern "C" {
LLAMA_API int llama_model_quantize( LLAMA_API int llama_model_quantize(
const char * fname_inp, const char * fname_inp,
const char * fname_out, const char * fname_out,
int itype, int itype);
int qk);
// Run the llama inference to obtain the logits and probabilities for the next token. // Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process // tokens + n_tokens is the provided batch of new tokens to process

View File

@ -13,7 +13,7 @@ int main(void) {
src[i] = (float)(i + 1); src[i] = (float)(i + 1);
} }
size_t size = ggml_quantize_q4_0(src, dst, QK, QK, QK, hist); size_t size = ggml_quantize_q4_0(src, dst, QK, QK, hist);
assert(size == 20); assert(size == 20);
float max_result = ((float *)dst)[0]; float max_result = ((float *)dst)[0];
float max_expected = src[31] / ((1 << 3) - 1); float max_expected = src[31] / ((1 << 3) - 1);
@ -24,7 +24,7 @@ int main(void) {
assert(q4_result == q4_expected); assert(q4_result == q4_expected);
} }
size = ggml_quantize_q4_1(src, dst, QK, QK, QK, hist); size = ggml_quantize_q4_1(src, dst, QK, QK, hist);
assert(size == 24); assert(size == 24);
float delta_result = ((float *)dst)[0]; float delta_result = ((float *)dst)[0];
float delta_expected = (src[31] - src[0]) / ((1 << 4) - 1); float delta_expected = (src[31] - src[0]) / ((1 << 4) - 1);