clean code 2

This commit is contained in:
root 2024-06-09 21:15:02 +08:00
parent 1c5a8b7fec
commit 3a0f8b0697
5 changed files with 100 additions and 51 deletions

View File

@ -2805,7 +2805,7 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "i2", "auto"], default="f16",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
@ -2865,7 +2865,6 @@ def main() -> None:
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"i2" : gguf.LlamaFileType.MOSTLY_I2,
"auto": gguf.LlamaFileType.GUESSED,
}

23
ggml.c
View File

@ -2724,7 +2724,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
};
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
@ -2813,7 +2812,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
};
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
@ -3078,10 +3076,9 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
if(tensor->type == 31){
if(tensor->type == GGML_TYPE_I2_S){
nbytes = nbytes / 4 + 32;
}
}
else {
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
@ -3107,6 +3104,7 @@ GGML_CALL size_t ggml_type_size(enum ggml_type type) {
GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) {
assert(ne % ggml_blck_size(type) == 0);
if (type == GGML_TYPE_I2_S) ne /= 4;
return ggml_type_size(type)*ne/ggml_blck_size(type);
}
@ -12333,11 +12331,11 @@ static void ggml_compute_forward_mul_mat_one_chunk(
return;
}
void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
size_t row_size = ggml_row_size(vec_dot_type, ne10);
if (src0->type == 31) {
row_size = ne10;
}
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
// if (src0->type == 31) {
// row_size = ne10;
// }
assert(ne12 % ne02 == 0);
assert(ne13 % ne03 == 0);
@ -12351,9 +12349,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
// attempt to reduce false-sharing (does not seem to make a difference)
// 16 * 2, accounting for mmla kernels
float tmp[32];
uint8_t *i_weight = (uint8_t*) (src0->data);
float * scale = (float * )((i_weight) + (ne00 * ne01 / 4));
float * act_scales = (float*) ((char *) wdata + (ne11 * ne10));
float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4));
const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10));
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
@ -12380,7 +12377,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
(src1_cont || src1->type != vec_dot_type
? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
: (i11 * nb11 + i12 * nb12 + i13 * nb13));
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
@ -12394,7 +12390,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
} else {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
}
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));

View File

@ -925,7 +925,6 @@ class GGMLQuantizationType(IntEnum):
F64 = 28
IQ1_M = 29
BF16 = 30
I2 = 31
# TODO: add GGMLFileType from ggml_ftype in ggml.h
@ -967,7 +966,6 @@ class LlamaFileType(IntEnum):
MOSTLY_IQ4_XS = 30 # except 1d tensors
MOSTLY_IQ1_M = 31 # except 1d tensors
MOSTLY_BF16 = 32 # except 1d tensors
MOSTLY_I2 = 33 # except 1d tensors
GUESSED = 1024 # not specified in the model file
@ -1034,7 +1032,6 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),
GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16),
GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64),
GGMLQuantizationType.I2: (1, 1),
GGMLQuantizationType.I8: (1, 1),
GGMLQuantizationType.I16: (1, 2),
GGMLQuantizationType.I32: (1, 4),

View File

@ -225,10 +225,8 @@ class GGUFWriter:
dtype = GGMLQuantizationType.I32
elif tensor_dtype == np.int64:
dtype = GGMLQuantizationType.I64
elif tensor_dtype == np.uint8:
dtype = GGMLQuantizationType.I2
else:
raise ValueError("Only F16, F32, F64, I8, I16, I32, I64, I2 tensors are supported for now")
raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
else:
dtype = raw_dtype
if tensor_dtype == np.uint8:
@ -239,9 +237,6 @@ class GGUFWriter:
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
self.ti_data += self._pack("I", dtype)
self.ti_data += self._pack("Q", self.offset_tensor)
if dtype == GGMLQuantizationType.I2:
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + self.data_alignment
else:
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
self.ti_data_count += 1
@ -257,8 +252,6 @@ class GGUFWriter:
self.temp_file = fp
shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
if (raw_dtype != GGMLQuantizationType.F32 or not name.endswith("scale")):
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
if self.temp_file is None:

View File

@ -3193,6 +3193,7 @@ struct llama_model_loader {
llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
}
@ -7029,7 +7030,6 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * wo_b,
struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask,
struct ggml_tensor * attn_sub_norm,
int32_t n_tokens,
int32_t n_kv,
float kq_scale,
@ -7124,15 +7124,6 @@ static struct ggml_tensor * llm_build_kqv(
cb(cur, "kqv_merged_cont", il);
}
if (model.arch == LLM_ARCH_BITNET)
{
cur = llm_build_norm(ctx, cur, hparams,
attn_sub_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_sub_norm", il);
}
ggml_build_forward_expand(graph, cur);
cur = ggml_mul_mat(ctx, wo, cur);
@ -7178,7 +7169,7 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_tensor * cur;
cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
q_cur, kq_mask, nullptr, n_tokens, n_kv, kq_scale, cb, il);
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
cb(cur, "kqv_out", il);
return cur;
@ -11590,10 +11581,84 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, KQ_mask, model.layers[il].attn_sub_norm, n_tokens, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il
);
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_head_v = hparams.n_embd_head_v;
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
struct ggml_tensor * q_cur = Qcur;
struct ggml_tensor * kq_mask = KQ_mask;
float kq_scale = 1.0f/sqrtf(float(n_embd_head));
struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
struct ggml_cgraph * graph = gf;
struct ggml_tensor * wo = model.layers[il].wo;
struct ggml_tensor * cur_attn;
struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
cb(q, "q", il);
struct ggml_tensor * k =
ggml_view_3d(ctx0, kv_self.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
0);
cb(k, "k", il);
if (cparams.flash_attn) {
// split cached v into n_head heads (not transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx0, kv_self.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
0);
cb(v, "v", il);
cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
cb(kq, "kq", il);
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
GGML_ASSERT(kv_self.size == n_ctx);
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx0, kv_self.v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv_self.v_l[il])*n_ctx,
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
0);
cb(v, "v", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
cb(kqv, "kqv", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
cb(cur_attn, "kqv_merged_cont", il);
}
cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
attn_sub_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur_attn, "attn_sub_norm", il);
ggml_build_forward_expand(graph, cur_attn);
cur = ggml_mul_mat(ctx0, wo, cur_attn);
cb(cur, "kqv_out", il);
}