mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
ggml : poc for normalizing weights for better quantization
This commit is contained in:
parent
1a941869cb
commit
675425563c
30
ggml-cuda.cu
30
ggml-cuda.cu
@ -74,14 +74,17 @@ typedef void (*ggml_cuda_op_t)(
|
||||
// QR = QK / number of values before dequantization
|
||||
// QI = number of 32 bit integers before dequantization
|
||||
|
||||
#define Q4_0DM (1.0f/8.0f)
|
||||
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QI4_0 (QK4_0 / (4 * QR4_0))
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
int8_t d; // delta
|
||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||
} block_q4_0;
|
||||
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||
static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
@ -103,16 +106,20 @@ typedef struct {
|
||||
} block_q5_0;
|
||||
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
|
||||
|
||||
#define Q5_1DM (2.0f/31.0f)
|
||||
#define Q5_1MM (2.0f )
|
||||
#define Q5_1D(x) ( (((x) & 0x0F)*Q5_1DM) / 15.0f)
|
||||
#define Q5_1M(x) (-1.0f + (((x) >> 4)*Q5_1MM) / 15.0f)
|
||||
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
#define QI5_1 (QK5_1 / (4 * QR5_1))
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uint8_t dm; // 4-bit delta + 4-bit min
|
||||
uint8_t qh[4]; // 5-th bit of quants
|
||||
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
||||
} block_q5_1;
|
||||
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
||||
static_assert(sizeof(block_q5_1) == sizeof(uint8_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
||||
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
@ -360,7 +367,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
|
||||
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
const dfloat d = x[ib].d;
|
||||
const dfloat d = Q4_0D(x[ib].d);
|
||||
|
||||
const int vui = x[ib].qs[iqs];
|
||||
|
||||
@ -422,8 +429,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
|
||||
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||
|
||||
const dfloat d = x[ib].d;
|
||||
const dfloat m = x[ib].m;
|
||||
const dfloat d = Q5_1D(x[ib].dm);
|
||||
const dfloat m = Q5_1M(x[ib].dm);
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||
@ -1336,7 +1343,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
|
||||
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
|
||||
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
|
||||
|
||||
const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d);
|
||||
const float d = Q4_0D(bq4_0->d) * __half2float(bq8_1->d);
|
||||
|
||||
// subtract 8 from each quantized value
|
||||
const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
|
||||
@ -1419,14 +1426,15 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
|
||||
|
||||
// TODO: fix misaligned access
|
||||
const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
|
||||
const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2);
|
||||
const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2);
|
||||
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
|
||||
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);
|
||||
|
||||
const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
|
||||
const float m = bq5_1->m;
|
||||
const float d = Q5_1D(bq5_1->dm) * __half2float(bq8_1->d);
|
||||
const float m = Q5_1M(bq5_1->dm);
|
||||
const float s = bq8_1->s;
|
||||
|
||||
int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
|
||||
|
100
ggml.c
100
ggml.c
@ -892,12 +892,16 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// we know the values are in the [-1 .. 1] range, so abs(d) cannot be more than 1/8 when using 4 bits
|
||||
#define Q4_0DM (1.0f/8.0f)
|
||||
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
|
||||
|
||||
#define QK4_0 32
|
||||
typedef struct {
|
||||
ggml_fp16_t d; // delta
|
||||
int8_t d; // delta
|
||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||
} block_q4_0;
|
||||
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||
static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||
|
||||
#define QK4_1 32
|
||||
typedef struct {
|
||||
@ -915,14 +919,21 @@ typedef struct {
|
||||
} block_q5_0;
|
||||
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
|
||||
|
||||
// we know the values are in the [-1 .. 1] range, so:
|
||||
// - d is unsigned 4-bit that represents maximum value of 2.0/31 when using 5 bits
|
||||
// - m is unsigned 4-bit that represents offset from -1.0 which cannot be more than 2.0
|
||||
#define Q5_1DM (2.0f/31.0f)
|
||||
#define Q5_1MM (2.0f )
|
||||
#define Q5_1D(x) ( (((x) & 0x0F)*Q5_1DM) / 15.0f)
|
||||
#define Q5_1M(x) (-1.0f + (((x) >> 4)*Q5_1MM) / 15.0f)
|
||||
|
||||
#define QK5_1 32
|
||||
typedef struct {
|
||||
ggml_fp16_t d; // delta
|
||||
ggml_fp16_t m; // min
|
||||
uint8_t dm; // 4-bit delta + 4-bit min
|
||||
uint8_t qh[4]; // 5-th bit of quants
|
||||
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
||||
} block_q5_1;
|
||||
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
||||
static_assert(sizeof(block_q5_1) == sizeof(uint8_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
||||
|
||||
#define QK8_0 32
|
||||
typedef struct {
|
||||
@ -959,10 +970,13 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -8;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
float d = max / -8;
|
||||
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
y[i].d = (int8_t)(ceilf((127.0f * d) / Q4_0DM));
|
||||
|
||||
d = Q4_0D(y[i].d);
|
||||
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
for (int j = 0; j < qk/2; ++j) {
|
||||
const float x0 = x[i*qk + 0 + j]*id;
|
||||
@ -1088,11 +1102,17 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
|
||||
if (v > max) max = v;
|
||||
}
|
||||
|
||||
const float d = (max - min) / ((1 << 5) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
y[i].dm = (uint8_t)(floorf((15.0f * (min + 1.0f)) / Q5_1MM)) << 4;
|
||||
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
y[i].m = GGML_FP32_TO_FP16(min);
|
||||
min = Q5_1M(y[i].dm);
|
||||
|
||||
float d = (max - min) / ((1 << 5) - 1);
|
||||
|
||||
y[i].dm |= (uint8_t)(ceilf((15.0f * d) / Q5_1DM));
|
||||
|
||||
d = Q5_1D(y[i].dm);
|
||||
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
uint32_t qh = 0;
|
||||
|
||||
@ -1530,7 +1550,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
|
||||
const int nb = k / qk;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d);
|
||||
const float d = Q4_0D(x[i].d);
|
||||
|
||||
for (int j = 0; j < qk/2; ++j) {
|
||||
const int x0 = (x[i].qs[j] & 0x0F) - 8;
|
||||
@ -1597,8 +1617,8 @@ static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict
|
||||
const int nb = k / qk;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d);
|
||||
const float m = GGML_FP16_TO_FP32(x[i].m);
|
||||
const float d = Q5_1D(x[i].dm);
|
||||
const float m = Q5_1M(x[i].dm);
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x[i].qh, sizeof(qh));
|
||||
@ -2407,8 +2427,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
|
||||
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), Q4_0D(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), Q4_0D(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
#else
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||
@ -2425,8 +2445,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
||||
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), Q4_0D(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), Q4_0D(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -2438,7 +2458,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
// Main loop
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
/* Compute combined scale for the block */
|
||||
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
|
||||
const __m256 d = _mm256_set1_ps( Q4_0D(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
|
||||
@ -2462,7 +2482,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
// Main loop
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
// Compute combined scale for the block
|
||||
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
|
||||
const __m256 d = _mm256_set1_ps( Q4_0D(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
|
||||
|
||||
const __m128i lowMask = _mm_set1_epi8(0xF);
|
||||
const __m128i off = _mm_set1_epi8(8);
|
||||
@ -2504,7 +2524,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
_mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
|
||||
|
||||
// Compute combined scale for the block 0 and 1
|
||||
const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
|
||||
const __m128 d_0_1 = _mm_set1_ps( Q4_0D(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
|
||||
|
||||
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
|
||||
|
||||
@ -2522,7 +2542,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
_mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
|
||||
|
||||
// Compute combined scale for the block 2 and 3
|
||||
const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
|
||||
const __m128 d_2_3 = _mm_set1_ps( Q4_0D(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
|
||||
|
||||
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
|
||||
|
||||
@ -2555,7 +2575,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
_mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
|
||||
|
||||
// Compute combined scale for the block 0 and 1
|
||||
const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
|
||||
const __m128 d_0_1 = _mm_set1_ps( Q4_0D(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
|
||||
|
||||
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
|
||||
|
||||
@ -2573,7 +2593,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
|
||||
|
||||
// Compute combined scale for the block 2 and 3
|
||||
const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
|
||||
const __m128 d_2_3 = _mm_set1_ps( Q4_0D(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
|
||||
|
||||
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
|
||||
|
||||
@ -2621,7 +2641,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
|
||||
}
|
||||
|
||||
sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
|
||||
sumf += sumi*Q4_0D(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
@ -3026,8 +3046,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
||||
|
||||
summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
|
||||
summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
|
||||
summs0 += Q5_1M(x0->dm) * y0->s;
|
||||
summs1 += Q5_1M(x1->dm) * y1->s;
|
||||
|
||||
// extract the 5th bit via lookup table ((b) << 4)
|
||||
memcpy(&qh0, x0->qh, sizeof(qh0));
|
||||
@ -3072,10 +3092,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
||||
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
|
||||
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
|
||||
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), Q5_1D(x0->dm)*y0->d);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
||||
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
|
||||
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
|
||||
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), Q5_1D(x1->dm)*y1->d);
|
||||
#else
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
|
||||
@ -3092,8 +3112,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
||||
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
||||
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), Q5_1D(x0->dm)*y0->d);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), Q5_1D(x1->dm)*y1->d);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -3111,7 +3131,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
||||
const block_q5_1 * restrict x0 = &x[i];
|
||||
const block_q8_1 * restrict y0 = &y[i];
|
||||
|
||||
summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
|
||||
summs += Q5_1M(x0->dm) * y0->s;
|
||||
|
||||
const v128_t m4b = wasm_i8x16_splat(0x0F);
|
||||
|
||||
@ -3158,7 +3178,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
||||
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
|
||||
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
|
||||
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
|
||||
wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
|
||||
wasm_f32x4_splat(Q5_1D(x0->dm) * y0->d)));
|
||||
}
|
||||
|
||||
*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
|
||||
@ -3171,9 +3191,9 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
||||
|
||||
// Main loop
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
|
||||
const __m256 dx = _mm256_set1_ps(Q5_1D(x[i].dm));
|
||||
|
||||
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
|
||||
summs += Q5_1M(x[i].dm) * y[i].s;
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
__m256i bxhi = bytes_from_bits_32(x[i].qh);
|
||||
@ -3198,9 +3218,9 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
||||
|
||||
// Main loop
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
|
||||
const __m256 dx = _mm256_set1_ps(Q5_1D(x[i].dm));
|
||||
|
||||
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
|
||||
summs += Q5_1M(x[i].dm) * y[i].s;
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
||||
@ -3243,7 +3263,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
||||
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
|
||||
}
|
||||
|
||||
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
|
||||
sumf += (Q5_1D(x[i].dm)*y[i].d)*sumi + Q5_1M(x[i].dm)*y[i].s;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
@ -5470,7 +5490,7 @@ struct ggml_tensor * ggml_sum_rows(
|
||||
}
|
||||
|
||||
int64_t ne[4] = {1,1,1,1};
|
||||
for (int i=1; i<a->n_dims; ++i) {
|
||||
for (int i = 1; i < a->n_dims; ++i) {
|
||||
ne[i] = a->ne[i];
|
||||
}
|
||||
|
||||
|
6
ggml.h
6
ggml.h
@ -281,9 +281,9 @@ extern "C" {
|
||||
GGML_TYPE_Q5_K = 13,
|
||||
GGML_TYPE_Q6_K = 14,
|
||||
GGML_TYPE_Q8_K = 15,
|
||||
GGML_TYPE_I8,
|
||||
GGML_TYPE_I16,
|
||||
GGML_TYPE_I32,
|
||||
GGML_TYPE_I8 = 16,
|
||||
GGML_TYPE_I16 = 17,
|
||||
GGML_TYPE_I32 = 18,
|
||||
GGML_TYPE_COUNT,
|
||||
};
|
||||
|
||||
|
125
llama.cpp
125
llama.cpp
@ -119,7 +119,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
|
||||
{
|
||||
static std::map<e_model, size_t> k_sizes = {
|
||||
{ MODEL_3B, 128ull * MB },
|
||||
{ MODEL_7B, 160ull * MB },
|
||||
{ MODEL_7B, 200ull * MB },
|
||||
{ MODEL_13B, 192ull * MB },
|
||||
{ MODEL_30B, 256ull * MB },
|
||||
{ MODEL_65B, 384ull * MB }, // guess
|
||||
@ -229,6 +229,11 @@ struct llama_layer {
|
||||
struct ggml_tensor * wv;
|
||||
struct ggml_tensor * wo;
|
||||
|
||||
struct ggml_tensor * wq_a;
|
||||
struct ggml_tensor * wk_a;
|
||||
struct ggml_tensor * wv_a;
|
||||
struct ggml_tensor * wo_a;
|
||||
|
||||
// normalization
|
||||
struct ggml_tensor * ffn_norm;
|
||||
|
||||
@ -236,6 +241,10 @@ struct llama_layer {
|
||||
struct ggml_tensor * w1;
|
||||
struct ggml_tensor * w2;
|
||||
struct ggml_tensor * w3;
|
||||
|
||||
struct ggml_tensor * w1_a;
|
||||
struct ggml_tensor * w2_a;
|
||||
struct ggml_tensor * w3_a;
|
||||
};
|
||||
|
||||
struct llama_kv_cache {
|
||||
@ -1208,17 +1217,29 @@ static void llama_model_load_internal(
|
||||
layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split);
|
||||
layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split);
|
||||
|
||||
layer.wq_a = ml->get_tensor(layers_i + ".attention.wq.weight.a", {n_embd}, backend);
|
||||
layer.wk_a = ml->get_tensor(layers_i + ".attention.wk.weight.a", {n_embd_gqa}, backend);
|
||||
layer.wv_a = ml->get_tensor(layers_i + ".attention.wv.weight.a", {n_embd_gqa}, backend);
|
||||
layer.wo_a = ml->get_tensor(layers_i + ".attention.wo.weight.a", {n_embd}, backend);
|
||||
|
||||
layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
|
||||
|
||||
layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split);
|
||||
layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split);
|
||||
layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split);
|
||||
|
||||
layer.w1_a = ml->get_tensor(layers_i + ".feed_forward.w1.weight.a", { n_ff}, backend);
|
||||
layer.w2_a = ml->get_tensor(layers_i + ".feed_forward.w2.weight.a", {n_embd}, backend);
|
||||
layer.w3_a = ml->get_tensor(layers_i + ".feed_forward.w3.weight.a", { n_ff}, backend);
|
||||
|
||||
if (backend == GGML_BACKEND_GPU) {
|
||||
vram_weights +=
|
||||
ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
|
||||
ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
|
||||
ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
|
||||
ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
|
||||
ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
|
||||
ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3) +
|
||||
ggml_nbytes(layer.wq_a) + ggml_nbytes(layer.wk_a) + ggml_nbytes(layer.wv_a) +
|
||||
ggml_nbytes(layer.wo_a) + ggml_nbytes(layer.w1_a) + ggml_nbytes(layer.w2_a) +
|
||||
ggml_nbytes(layer.w3_a);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1360,6 +1381,34 @@ static bool llama_model_load(
|
||||
}
|
||||
}
|
||||
|
||||
// computes: Z = (X @ Y) * a
|
||||
// a is vector with size equal to rows of X. each element is the scaling factor used to normalize X's rows
|
||||
// the ggml_mul() is broadcasted row-wise to restore the normalization
|
||||
struct ggml_tensor * ggml_mul_mat_ex(
|
||||
struct ggml_context * ctx0,
|
||||
struct ggml_tensor * t,
|
||||
struct ggml_tensor * a,
|
||||
//struct ggml_tensor * b,
|
||||
struct ggml_tensor * cur,
|
||||
offload_func_t offload_func) {
|
||||
cur = ggml_mul_mat(ctx0, t, cur);
|
||||
offload_func(cur);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, a);
|
||||
offload_func(cur);
|
||||
|
||||
return cur;
|
||||
|
||||
//struct ggml_tensor * tmp = ggml_mul_mat(ctx0, t, cur);
|
||||
//tmp = ggml_mul(ctx0, tmp, a);
|
||||
//cur = ggml_add(ctx0, tmp,
|
||||
// ggml_mul(ctx0,
|
||||
// ggml_repeat(ctx0, ggml_sum_rows(ctx0, cur), tmp),
|
||||
// b)
|
||||
// );
|
||||
//return cur;
|
||||
}
|
||||
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - lctx: llama context
|
||||
@ -1502,12 +1551,10 @@ static bool llama_eval_internal(
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
offload_func_kq(tmpk);
|
||||
struct ggml_tensor * tmpk = ggml_mul_mat_ex(ctx0, model.layers[il].wk, model.layers[il].wk_a, cur, offload_func_kq);
|
||||
ggml_set_name(tmpk, "tmpk");
|
||||
|
||||
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
offload_func_kq(tmpq);
|
||||
struct ggml_tensor * tmpq = ggml_mul_mat_ex(ctx0, model.layers[il].wq, model.layers[il].wq_a, cur, offload_func_kq);
|
||||
ggml_set_name(tmpq, "tmpq");
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
@ -1522,8 +1569,7 @@ static bool llama_eval_internal(
|
||||
{
|
||||
// compute the transposed [N, n_embd] V matrix
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
offload_func_v(tmpv);
|
||||
struct ggml_tensor * tmpv = ggml_mul_mat_ex(ctx0, model.layers[il].wv, model.layers[il].wv_a, cur, offload_func_v);
|
||||
ggml_set_name(tmpv, "tmpv");
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
|
||||
@ -1620,10 +1666,7 @@ static bool llama_eval_internal(
|
||||
ggml_set_name(cur, "KQV_merged_contiguous");
|
||||
|
||||
// projection (no bias)
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].wo,
|
||||
cur);
|
||||
offload_func(cur);
|
||||
cur = ggml_mul_mat_ex(ctx0, model.layers[il].wo, model.layers[il].wo_a, cur, offload_func);
|
||||
ggml_set_name(cur, "result_wo");
|
||||
}
|
||||
|
||||
@ -1647,16 +1690,10 @@ static bool llama_eval_internal(
|
||||
ggml_set_name(cur, "ffn_norm");
|
||||
}
|
||||
|
||||
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w3,
|
||||
cur);
|
||||
offload_func(tmp);
|
||||
struct ggml_tensor * tmp = ggml_mul_mat_ex(ctx0, model.layers[il].w3, model.layers[il].w3_a, cur, offload_func);
|
||||
ggml_set_name(tmp, "result_w3");
|
||||
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w1,
|
||||
cur);
|
||||
offload_func(cur);
|
||||
cur = ggml_mul_mat_ex(ctx0, model.layers[il].w1, model.layers[il].w1_a, cur, offload_func);
|
||||
ggml_set_name(cur, "result_w1");
|
||||
|
||||
// SILU activation
|
||||
@ -1668,10 +1705,7 @@ static bool llama_eval_internal(
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "silu_x_result_w3");
|
||||
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w2,
|
||||
cur);
|
||||
offload_func(cur);
|
||||
cur = ggml_mul_mat_ex(ctx0, model.layers[il].w2, model.layers[il].w2_a, cur, offload_func);
|
||||
ggml_set_name(cur, "result_w2");
|
||||
}
|
||||
|
||||
@ -2936,7 +2970,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
} else {
|
||||
new_type = quantized_type;
|
||||
#ifdef GGML_USE_K_QUANTS
|
||||
if (tensor.name == "output.weight") {
|
||||
if (tensor.name == "output.weight" || tensor.name == "tok_embeddings.weight") {
|
||||
int nx = tensor.ne.at(0);
|
||||
int ny = tensor.ne.at(1);
|
||||
if (nx % QK_K == 0 && ny % QK_K == 0) {
|
||||
@ -2997,6 +3031,43 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
f32_data = (float *) f32_conv_buf.addr;
|
||||
}
|
||||
|
||||
// TODO: this is temporary since we only implemented Q4_0 and Q5_1 as POC
|
||||
if (new_type == GGML_TYPE_Q4_0 || new_type == GGML_TYPE_Q5_1) {
|
||||
//printf("\n dims: %d x %d\n", tensor.ne.at(0), tensor.ne.at(1));
|
||||
|
||||
const uint32_t nr = tensor.ne.at(1);
|
||||
|
||||
std::vector<float> va(nr);
|
||||
std::vector<float> vb(nr);
|
||||
|
||||
// normalize to -1..1 per rows
|
||||
for (uint32_t r = 0; r < nr; ++r) {
|
||||
const uint32_t n = tensor.ne.at(0);
|
||||
float * p = f32_data + r * n;
|
||||
|
||||
float amax = 0.0f;
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
amax = std::max(amax, std::abs(p[i]));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
p[i] = p[i] / amax;
|
||||
}
|
||||
|
||||
va[r] = amax;
|
||||
}
|
||||
|
||||
{
|
||||
llama_load_tensor ta;
|
||||
ta.name = tensor.name + ".a";
|
||||
ta.type = GGML_TYPE_F32;
|
||||
ta.ne = std::vector<uint32_t>(1, nr);
|
||||
ta.size = nr * sizeof(float);
|
||||
ta.data = (uint8_t *) va.data();
|
||||
file_saver.write_tensor(ta, GGML_TYPE_F32, ta.data, ta.size);
|
||||
}
|
||||
}
|
||||
|
||||
printf("quantizing to %s .. ", ggml_type_name(new_type));
|
||||
fflush(stdout);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user