clean code

root 2024-06-09 20:22:03 +08:00
parent dbee0a86c1
commit 1c5a8b7fec
7 changed files with 64 additions and 216 deletions

examples/quantize/quantize.cpp

@@ -26,7 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_M",   LLAMA_FTYPE_MOSTLY_IQ2_M,   " 2.7 bpw quantization", },
     { "IQ1_S",   LLAMA_FTYPE_MOSTLY_IQ1_S,   " 1.56 bpw quantization", },
     { "IQ1_M",   LLAMA_FTYPE_MOSTLY_IQ1_M,   " 1.75 bpw quantization", },
-    { "I2_S",    LLAMA_FTYPE_MOSTLY_I2,      " 2 bpw per-tensor", },
+    { "I2_S",    LLAMA_FTYPE_MOSTLY_I2_S,    " 2 bpw per-tensor quantization", },
     { "Q2_K",    LLAMA_FTYPE_MOSTLY_Q2_K,    " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S",  LLAMA_FTYPE_MOSTLY_Q2_K_S,  " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
     { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },

ggml-quants.c

@@ -659,6 +659,24 @@ static inline __m128i packNibbles( __m256i bytes ) {
 }
 #endif //__loongarch_asx
+void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) {
+    int8_t* dst = (int8_t*)y;
+    double min = 0.00001;
+    double max = min;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, (double)fabs(x[i]));
+    }
+    float s = 127 / max;
+    act_scales[0] = s;
+    float temp;
+    for (int i = 0; i < n; ++i) {
+        temp = round(x[i] * s);
+        if (temp >  127) temp =  127;
+        if (temp < -128) temp = -128;
+        dst[i] = (int8_t)(temp);
+    }
+}
 // reference implementation for deterministic creation of model files
 void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
     static const int qk = QK4_0;
@@ -3308,7 +3326,9 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
 size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     // 2 bits per weight
-    size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row);
+    UNUSED(quant_weights);
+
+    size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
     int n = nrow * n_per_row;
@@ -3326,7 +3346,7 @@ size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nr
             q8[i] = 0;
             continue;
         }
-        q8[i] = src[i] * i2_scale > 0 ? 1 : 3;
+        q8[i] = (double)src[i] * i2_scale > 0 ? 1 : 3;
     }
 
     // q8 -> 0, 1, 3
@@ -3773,14 +3793,19 @@ static inline __m128i get_scale_shuffle(int i) {
 //====================================== I2 ===============================================
-void ggml_vec_dot_i2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const uint8_t * restrict x = vx;
     const int8_t  * restrict y = vy;
+    UNUSED(bs);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(nrc);
+
     int sumi = 0;
     for (int i = 0; i < n / 4; i++) {
-        int8_t* weight = (const int8_t *)(i2_q8 + x[i]);
+        const int8_t* weight = (const int8_t *)(i2_q8 + x[i]);
         sumi += (int)y[i*4+0] * weight[0];
         sumi += (int)y[i*4+1] * weight[1];
         sumi += (int)y[i*4+2] * weight[2];
@@ -14431,7 +14456,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_I64:
-        case GGML_TYPE_I2:
+        case GGML_TYPE_I2_S:
             // nothing to validate
             break;
         default:
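
A note on the kernel above: each byte of an I2_S tensor packs four 2-bit codes, and ggml_vec_dot_i2_i8_s expands it through the precomputed i2_q8 table into four signed weights that multiply four int8 activations. Below is a minimal, self-contained sketch of that lookup-table trick. It assumes the code mapping 0 -> 0, 1 -> +1, 3 -> -1 (consistent with the "q8 -> 0, 1, 3" comment in quantize_i2_s) and a little-end-first bit order within the byte; the table construction and names here are illustrative, not the real ggml internals.

#include <stdint.h>
#include <stdio.h>

// Map a 2-bit code to a ternary weight: 0 -> 0, 1 -> +1, 3 -> -1 (assumed).
static int8_t code_to_weight(uint8_t c) {
    return c == 1 ? 1 : (c == 3 ? -1 : 0);
}

int main(void) {
    // 256-entry LUT: each packed byte expands to four int8 weights,
    // playing the role of i2_q8 in ggml-quants.c.
    static int8_t lut[256][4];
    for (int b = 0; b < 256; ++b) {
        for (int j = 0; j < 4; ++j) {
            lut[b][j] = code_to_weight((uint8_t)((b >> (j * 2)) & 3));
        }
    }

    // Pack the weights {+1, -1, 0, +1} as codes {1, 3, 0, 1}.
    const uint8_t packed = (uint8_t)((1 << 0) | (3 << 2) | (0 << 4) | (1 << 6));
    const int8_t act[4] = { 10, 20, 30, 40 };  // int8-quantized activations

    int sumi = 0;
    for (int j = 0; j < 4; ++j) {
        sumi += (int)act[j] * lut[packed][j];  // same inner product as the kernel
    }
    printf("dot = %d\n", sumi);  // 10*1 + 20*(-1) + 30*0 + 40*1 = 30
    return 0;
}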

ggml-quants.h

@@ -51,6 +51,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
 void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_iq3_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_iq2_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_i8_s   (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, float* n);
 
 // Dequantization
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -99,7 +100,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_i2_q8_0     (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_i2_i8_s     (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
 size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
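
quantize_row_i8_s above pairs with the rescale applied after the integer dot product in ggml.c (tmp = int_dot / act_scale * weight_scale). A minimal round-trip sketch under that reading; the row values and the per-tensor weight scale are invented for illustration:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const float x[4] = { 0.5f, -1.25f, 2.0f, -0.75f };  // one activation row
    const int n = 4;

    // Per-row absmax scale, clamped away from zero as in quantize_row_i8_s.
    double maxv = 0.00001;
    for (int i = 0; i < n; ++i) maxv = fmax(maxv, fabs((double)x[i]));
    const float s = (float)(127.0 / maxv);

    int8_t q[4];
    for (int i = 0; i < n; ++i) {
        float t = roundf(x[i] * s);
        if (t >  127.0f) t =  127.0f;
        if (t < -128.0f) t = -128.0f;
        q[i] = (int8_t)t;
    }

    // Ternary weights and an invented per-tensor weight scale.
    const int8_t w[4]   = { 1, -1, 0, 1 };
    const float w_scale = 0.8f;

    int sumi = 0;
    for (int i = 0; i < n; ++i) sumi += (int)q[i] * w[i];

    // Undo the activation scale, then apply the weight scale,
    // mirroring tmp = tmp / act_scales[i11] * (*scale) in ggml.c.
    const float y = (float)sumi / s * w_scale;

    float ref = 0.0f;
    for (int i = 0; i < n; ++i) ref += x[i] * (float)w[i];
    ref *= w_scale;

    printf("quantized = %f, exact = %f\n", y, ref);  // ~0.794 vs 0.800
    return 0;
}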

ggml.c

@@ -569,15 +569,6 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
 static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
 
 static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_I2] = {
-        .type_name = "i2",
-        .blck_size = 1,
-        .type_size = sizeof(int8_t),
-        .is_quantized = true,
-        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_q8_0,
-        .vec_dot_type = GGML_TYPE_Q8_0,
-        .nrows = 1,
-    },
     [GGML_TYPE_I8] = {
         .type_name = "i8",
         .blck_size = 1,
@@ -922,6 +913,21 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
         .vec_dot_type = GGML_TYPE_BF16,
         .nrows = 1,
     },
+    [GGML_TYPE_I2_S] = {
+        .type_name = "i2_s",
+        .blck_size = 1,
+        .type_size = sizeof(int8_t),
+        .is_quantized = true,
+        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_i8_s,
+        .vec_dot_type = GGML_TYPE_I8_S,
+        .nrows = 1,
+    },
+    [GGML_TYPE_I8_S] = {
+        .type_name = "i8_s",
+        .blck_size = 1,
+        .type_size = sizeof(int8_t),
+        .is_quantized = true,
+    }
 };
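
The entries above are what route I2_S matrix multiplications to ggml_vec_dot_i2_i8_s with activations quantized to I8_S. A stripped-down sketch of this table-driven dispatch pattern, using demo types and a placeholder kernel rather than ggml's real definitions:

#include <stdio.h>

typedef void (*vec_dot_t)(int n, float * s, const void * x, const void * y);

// Placeholder kernel; the real one is ggml_vec_dot_i2_i8_s.
static void vec_dot_demo(int n, float * s, const void * x, const void * y) {
    (void)x; (void)y;
    *s = (float)n;
}

enum demo_type { DEMO_TYPE_I2_S, DEMO_TYPE_I8_S, DEMO_TYPE_COUNT };

struct demo_traits {
    const char   * type_name;
    enum demo_type vec_dot_type;  // the type activations must be converted to
    vec_dot_t      vec_dot;
};

static const struct demo_traits traits[DEMO_TYPE_COUNT] = {
    [DEMO_TYPE_I2_S] = { "i2_s", DEMO_TYPE_I8_S, vec_dot_demo },
    [DEMO_TYPE_I8_S] = { "i8_s", DEMO_TYPE_I8_S, NULL },
};

int main(void) {
    float s = 0.0f;
    const struct demo_traits * tt = &traits[DEMO_TYPE_I2_S];
    tt->vec_dot(64, &s, NULL, NULL);  // mul_mat looks up and calls through the table
    printf("%s -> %.1f\n", tt->type_name, s);
    return 0;
}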
@@ -2630,33 +2636,6 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
     *s = idx;
 }
 
-inline static void ggml_vec_absmaxclamp_f32(const int n, float * s, float * x, float min) {
-    float max = min;
-    for (int i = 0; i < n; ++i) {
-        max = MAX(max, fabs(x[i]));
-    }
-    *s = max;
-}
-
-inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const float * x, float scale, float min, float max) {
-    for (int i = 0; i < n; ++i) {
-        s[i] = round(x[i] * scale);
-        if (s[i] > max) s[i] = max;
-        if (s[i] < min) s[i] = min;
-        s[i] /= scale;
-    }
-}
-
-inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_t* inp, float scale, float min, float max) {
-    float temp;
-    for (int i = 0; i < n; ++i) {
-        temp = round(s[i] * scale);
-        if (temp > max) temp = max;
-        if (temp < min) temp = min;
-        inp[i] = (int8_t)(temp);
-    }
-}
-
 //
 // data types
 //
@@ -12409,8 +12388,7 @@ static void ggml_compute_forward_mul_mat_one_chunk(
             //}
 
             for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
-                if (src0->type == 31) {
-                    // printf("row->%ld\n", (ir0 * nb01 / 4));
+                if (src0->type == GGML_TYPE_I2_S) {
                     vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
                     tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale);
                 } else {
@@ -12426,164 +12404,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
     }
 }
 
-static void ggml_compute_forward_bitnet_mul_mat(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst,
-              struct ggml_compute_state * state) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const enum ggml_type type = src0->type;
-
-    const bool src1_cont = ggml_is_contiguous(src1);
-
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    // broadcast factors
-    const int64_t r2 = ne12 / ne02;
-    const int64_t r3 = ne13 / ne03;
-    UNUSED(r2);
-    UNUSED(r3);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
-        atomic_store(&state->shared->current_chunk, nth);
-        char * wdata = params->wdata;
-        float* act_scales = (float*) ((char *) wdata + (ne11 * ne10));
-        for (int64_t i13 = 0; i13 < ne13; i13++) {
-            for (int64_t i12 = 0; i12 < ne12; i12++) {
-                for (int64_t i11 = 0; i11 < ne11; i11++) {
-                    float rowmax = 0.00001;
-                    ggml_vec_absmaxclamp_f32(ne10, &rowmax, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), 0.00001);
-                    float s = 127 / rowmax;
-                    act_scales[i11] = s;
-                    ggml_vec_scaleroundclamp_f32_v2(ne10,
-                            (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13),
-                            (int8_t*) ((char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4)),
-                            s, -128, 127);
-                }
-            }
-        }
-
-        // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
-        // atomic_store(&state->shared->current_chunk, nth);
-        // // char * wdata = params->wdata;
-        // const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, ne10);
-        // // printf("vec_dot_type:%d\n", vec_dot_type);
-        // // printf("row_size:%ld\n", row_size);
-        // assert(params->wsize >= ne11*ne12*ne13*row_size);
-        // GGML_ASSERT(src1->type == GGML_TYPE_F32);
-        // for (int64_t i13 = 0; i13 < ne13; ++i13) {
-        //     for (int64_t i12 = 0; i12 < ne12; ++i12) {
-        //         for (int64_t i11 = 0; i11 < ne11; ++i11) {
-        //             quantize_row_q8_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-        //             wdata += row_size;
-        //         }
-        //     }
-        // }
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
-    const int64_t nr0 = ne0;
-
-    // This is the size of the rest of the dimensions of the result
-    const int64_t nr1 = ne1 * ne2 * ne3;
-
-    // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
-    int64_t num_rows_per_vec_dot = 1;
-    // TODO: currently the mmla kernels support only even numbered rows/cols.
-    // this check can be removed once they are extended to support odd numbered rows/cols too
-    if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
-        num_rows_per_vec_dot = 1;
-    }
-
-    // Now select a reasonable chunk size.
-    int chunk_size = 16;
-
-    // We need to step up the size if it's small
-    if (nr0 == 1 || nr1 == 1) {
-        chunk_size = 64;
-    }
-
-    // distribute the work across the inner or outer loop based on which one is larger
-    // The number of chunks in the 0/1 dim.
-    // CEIL(nr0/chunk_size)
-    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
-    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
-
-    // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
-    //   Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
-    //   In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
-    if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
-        // distribute the thread work across the inner or outer loop based on which one is larger
-        nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-        nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-    }
-
-    // The number of elements in each chunk
-    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
-    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
-
-    //if (ith == 0)
-    //    printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
-
-    // The first chunk comes from our thread_id, the rest will get auto-assigned.
-    int current_chunk = ith;
-
-    while (current_chunk < nchunk0 * nchunk1) {
-        const int64_t ith0 = current_chunk % nchunk0;
-        const int64_t ith1 = current_chunk / nchunk0;
-
-        const int64_t ir0_start = dr0 * ith0;
-        const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
-
-        const int64_t ir1_start = dr1 * ith1;
-        const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
-
-        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
-
-        if (nth >= nchunk0 * nchunk1) {
-            break;
-        }
-
-        current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
-    }
-}
-
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst,
@@ -12597,11 +12417,6 @@ static void ggml_compute_forward_mul_mat(
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    if (src0->type == 31) {
-        ggml_compute_forward_bitnet_mul_mat(params, dst, state);
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
@@ -12751,8 +12566,13 @@ UseGgmlGemm1:;
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                    wdata += row_size;
+                    if (src0->type == GGML_TYPE_I2_S) {
+                        float* act_scales = (float*) ((char *) wdata + (ne11 * ne10));
+                        quantize_row_i8_s((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4), ne10, act_scales + i11);
+                    } else {
+                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                        wdata += row_size;
+                    }
                 }
             }
         }
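
In the I2_S branch above, wdata holds ne11*ne10 int8 activation bytes followed by one float scale per src1 row, and the f32 byte offset i11*nb11 + i12*nb12 + i13*nb13 is divided by 4 because each 4-byte float becomes a single int8 byte. A small sketch of that layout arithmetic, with made-up sizes:

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    const int64_t ne10 = 8;                     // elements per src1 row
    const int64_t ne11 = 2;                     // number of src1 rows
    const int64_t nb11 = ne10 * sizeof(float);  // f32 row stride in bytes

    // int8 rows first, then the per-row activation scales.
    char  * wdata      = malloc(ne11 * ne10 + ne11 * sizeof(float));
    float * act_scales = (float *) (wdata + ne11 * ne10);

    for (int64_t i11 = 0; i11 < ne11; ++i11) {
        // f32 byte offset / 4 == int8 byte offset of the same row
        int8_t * row = (int8_t *) (wdata + (i11 * nb11) / 4);
        row[0]          = 42;     // stand-in for quantize_row_i8_s output
        act_scales[i11] = 63.5f;  // stand-in for the row's 127/absmax scale
        printf("row %lld starts at byte %lld\n",
               (long long) i11, (long long) ((i11 * nb11) / 4));
    }

    free(wdata);
    return 0;
}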
@@ -14469,7 +14289,8 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_I32:
         case GGML_TYPE_I64:
         case GGML_TYPE_F64:
-        case GGML_TYPE_I2:
+        case GGML_TYPE_I2_S:
+        case GGML_TYPE_I8_S:
         case GGML_TYPE_COUNT:
             {
                 GGML_ASSERT(false);
@@ -21727,7 +21548,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_IQ1_M:  result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-        case GGML_TYPE_I2:     result = quantize_i2_s   (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_I2_S:   result = quantize_i2_s   (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
@@ -21750,7 +21571,7 @@ size_t ggml_quantize_chunk(
         assert(false);
     }
 
-    if (type == GGML_TYPE_I2) {
+    if (type == GGML_TYPE_I2_S) {
         result = nrows * row_size / 4 + 32;
     } else {
         GGML_ASSERT(result == nrows * row_size);
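
For I2_S, ggml_row_size reports one byte per element (blck_size 1, type_size 1), but the data actually written packs four 2-bit weights per byte, and the extra 32 bytes appear to hold the per-tensor scale; hence the special case result = nrows * row_size / 4 + 32 above. Worked numbers as a sketch:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t nrows     = 4096;
    const int64_t n_per_row = 4096;

    const int64_t row_size = n_per_row;                  // nominal: 1 byte per element
    const int64_t packed   = nrows * row_size / 4 + 32;  // 2 bits/weight + scale bytes

    printf("nominal %lld bytes -> stored %lld bytes\n",
           (long long)(nrows * row_size), (long long) packed);  // 16777216 -> 4194336
    return 0;
}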

ggml.h

@@ -377,7 +377,8 @@ extern "C" {
         GGML_TYPE_F64   = 28,
         GGML_TYPE_IQ1_M = 29,
         GGML_TYPE_BF16  = 30,
-        GGML_TYPE_I2    = 31,
+        GGML_TYPE_I2_S  = 31,
+        GGML_TYPE_I8_S  = 32,
         GGML_TYPE_COUNT,
     };

llama.cpp

@@ -15634,7 +15634,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
         case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
         case LLAMA_FTYPE_ALL_F32:     default_type = GGML_TYPE_F32;  break;
-        case LLAMA_FTYPE_MOSTLY_I2:   default_type = GGML_TYPE_I2;   break;
+        case LLAMA_FTYPE_MOSTLY_I2_S: default_type = GGML_TYPE_I2_S; break;
 
         // K-quants
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:

llama.h

@@ -156,7 +156,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M  = 31, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_BF16   = 32, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_I2     = 33,
+        LLAMA_FTYPE_MOSTLY_I2_S   = 33,
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };