mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 05:42:22 +01:00
clean code 2
This commit is contained in:
parent
1c5a8b7fec
commit
3a0f8b0697
@ -2805,7 +2805,7 @@ def parse_args() -> argparse.Namespace:
|
|||||||
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "i2", "auto"], default="f16",
|
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
|
||||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -2865,7 +2865,6 @@ def main() -> None:
|
|||||||
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||||
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||||
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
||||||
"i2" : gguf.LlamaFileType.MOSTLY_I2,
|
|
||||||
"auto": gguf.LlamaFileType.GUESSED,
|
"auto": gguf.LlamaFileType.GUESSED,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
35
ggml.c
35
ggml.c
@ -2724,7 +2724,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
|
|
||||||
"CROSS_ENTROPY_LOSS",
|
"CROSS_ENTROPY_LOSS",
|
||||||
"CROSS_ENTROPY_LOSS_BACK",
|
"CROSS_ENTROPY_LOSS_BACK",
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
||||||
@ -2813,7 +2812,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
|
|
||||||
"cross_entropy_loss(x,y)",
|
"cross_entropy_loss(x,y)",
|
||||||
"cross_entropy_loss_back(x,y)",
|
"cross_entropy_loss_back(x,y)",
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
||||||
@ -3078,10 +3076,9 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
|||||||
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||||
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
||||||
}
|
}
|
||||||
if(tensor->type == 31){
|
if(tensor->type == GGML_TYPE_I2_S){
|
||||||
nbytes = nbytes / 4 + 32;
|
nbytes = nbytes / 4 + 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
|
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
|
||||||
@ -3107,6 +3104,7 @@ GGML_CALL size_t ggml_type_size(enum ggml_type type) {
|
|||||||
|
|
||||||
GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) {
|
GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) {
|
||||||
assert(ne % ggml_blck_size(type) == 0);
|
assert(ne % ggml_blck_size(type) == 0);
|
||||||
|
if (type == GGML_TYPE_I2_S) ne /= 4;
|
||||||
return ggml_type_size(type)*ne/ggml_blck_size(type);
|
return ggml_type_size(type)*ne/ggml_blck_size(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -12333,11 +12331,11 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||||
size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||||
if (src0->type == 31) {
|
// if (src0->type == 31) {
|
||||||
row_size = ne10;
|
// row_size = ne10;
|
||||||
}
|
// }
|
||||||
|
|
||||||
assert(ne12 % ne02 == 0);
|
assert(ne12 % ne02 == 0);
|
||||||
assert(ne13 % ne03 == 0);
|
assert(ne13 % ne03 == 0);
|
||||||
@ -12351,9 +12349,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
|||||||
// attempt to reduce false-sharing (does not seem to make a difference)
|
// attempt to reduce false-sharing (does not seem to make a difference)
|
||||||
// 16 * 2, accounting for mmla kernels
|
// 16 * 2, accounting for mmla kernels
|
||||||
float tmp[32];
|
float tmp[32];
|
||||||
uint8_t *i_weight = (uint8_t*) (src0->data);
|
float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4));
|
||||||
float * scale = (float * )((i_weight) + (ne00 * ne01 / 4));
|
const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10));
|
||||||
float * act_scales = (float*) ((char *) wdata + (ne11 * ne10));
|
|
||||||
|
|
||||||
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
||||||
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
||||||
@ -12380,7 +12377,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
|||||||
(src1_cont || src1->type != vec_dot_type
|
(src1_cont || src1->type != vec_dot_type
|
||||||
? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
|
? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
|
||||||
: (i11 * nb11 + i12 * nb12 + i13 * nb13));
|
: (i11 * nb11 + i12 * nb12 + i13 * nb13));
|
||||||
|
|
||||||
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
||||||
|
|
||||||
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
|
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
|
||||||
@ -12388,13 +12384,12 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
|||||||
//}
|
//}
|
||||||
|
|
||||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
||||||
if (src0->type == GGML_TYPE_I2_S) {
|
if (src0->type == GGML_TYPE_I2_S) {
|
||||||
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
||||||
tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale);
|
tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale);
|
||||||
} else {
|
} else {
|
||||||
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
|
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
|
||||||
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
|
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
|
||||||
|
@ -925,7 +925,6 @@ class GGMLQuantizationType(IntEnum):
|
|||||||
F64 = 28
|
F64 = 28
|
||||||
IQ1_M = 29
|
IQ1_M = 29
|
||||||
BF16 = 30
|
BF16 = 30
|
||||||
I2 = 31
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
||||||
@ -967,7 +966,6 @@ class LlamaFileType(IntEnum):
|
|||||||
MOSTLY_IQ4_XS = 30 # except 1d tensors
|
MOSTLY_IQ4_XS = 30 # except 1d tensors
|
||||||
MOSTLY_IQ1_M = 31 # except 1d tensors
|
MOSTLY_IQ1_M = 31 # except 1d tensors
|
||||||
MOSTLY_BF16 = 32 # except 1d tensors
|
MOSTLY_BF16 = 32 # except 1d tensors
|
||||||
MOSTLY_I2 = 33 # except 1d tensors
|
|
||||||
|
|
||||||
GUESSED = 1024 # not specified in the model file
|
GUESSED = 1024 # not specified in the model file
|
||||||
|
|
||||||
@ -1034,7 +1032,6 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
|
|||||||
GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),
|
GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),
|
||||||
GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16),
|
GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16),
|
||||||
GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64),
|
GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64),
|
||||||
GGMLQuantizationType.I2: (1, 1),
|
|
||||||
GGMLQuantizationType.I8: (1, 1),
|
GGMLQuantizationType.I8: (1, 1),
|
||||||
GGMLQuantizationType.I16: (1, 2),
|
GGMLQuantizationType.I16: (1, 2),
|
||||||
GGMLQuantizationType.I32: (1, 4),
|
GGMLQuantizationType.I32: (1, 4),
|
||||||
|
@ -225,10 +225,8 @@ class GGUFWriter:
|
|||||||
dtype = GGMLQuantizationType.I32
|
dtype = GGMLQuantizationType.I32
|
||||||
elif tensor_dtype == np.int64:
|
elif tensor_dtype == np.int64:
|
||||||
dtype = GGMLQuantizationType.I64
|
dtype = GGMLQuantizationType.I64
|
||||||
elif tensor_dtype == np.uint8:
|
|
||||||
dtype = GGMLQuantizationType.I2
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Only F16, F32, F64, I8, I16, I32, I64, I2 tensors are supported for now")
|
raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
|
||||||
else:
|
else:
|
||||||
dtype = raw_dtype
|
dtype = raw_dtype
|
||||||
if tensor_dtype == np.uint8:
|
if tensor_dtype == np.uint8:
|
||||||
@ -239,10 +237,7 @@ class GGUFWriter:
|
|||||||
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
|
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
|
||||||
self.ti_data += self._pack("I", dtype)
|
self.ti_data += self._pack("I", dtype)
|
||||||
self.ti_data += self._pack("Q", self.offset_tensor)
|
self.ti_data += self._pack("Q", self.offset_tensor)
|
||||||
if dtype == GGMLQuantizationType.I2:
|
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
|
||||||
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + self.data_alignment
|
|
||||||
else:
|
|
||||||
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
|
|
||||||
self.ti_data_count += 1
|
self.ti_data_count += 1
|
||||||
|
|
||||||
def add_tensor(
|
def add_tensor(
|
||||||
@ -257,9 +252,7 @@ class GGUFWriter:
|
|||||||
self.temp_file = fp
|
self.temp_file = fp
|
||||||
|
|
||||||
shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
|
shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
|
||||||
|
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
|
||||||
if (raw_dtype != GGMLQuantizationType.F32 or not name.endswith("scale")):
|
|
||||||
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
|
|
||||||
|
|
||||||
if self.temp_file is None:
|
if self.temp_file is None:
|
||||||
self.tensors.append(tensor)
|
self.tensors.append(tensor)
|
||||||
|
97
llama.cpp
97
llama.cpp
@ -262,7 +262,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||||||
{ LLM_ARCH_OLMO, "olmo" },
|
{ LLM_ARCH_OLMO, "olmo" },
|
||||||
{ LLM_ARCH_ARCTIC, "arctic" },
|
{ LLM_ARCH_ARCTIC, "arctic" },
|
||||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||||
{ LLM_ARCH_BITNET, "bitnet" },
|
{ LLM_ARCH_BITNET, "bitnet" },
|
||||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -3193,6 +3193,7 @@ struct llama_model_loader {
|
|||||||
llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
|
llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
|
||||||
const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
|
const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
|
||||||
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
|
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
|
||||||
|
|
||||||
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
|
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
|
||||||
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
|
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
|
||||||
}
|
}
|
||||||
@ -7029,7 +7030,6 @@ static struct ggml_tensor * llm_build_kqv(
|
|||||||
struct ggml_tensor * wo_b,
|
struct ggml_tensor * wo_b,
|
||||||
struct ggml_tensor * q_cur,
|
struct ggml_tensor * q_cur,
|
||||||
struct ggml_tensor * kq_mask,
|
struct ggml_tensor * kq_mask,
|
||||||
struct ggml_tensor * attn_sub_norm,
|
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int32_t n_kv,
|
int32_t n_kv,
|
||||||
float kq_scale,
|
float kq_scale,
|
||||||
@ -7124,15 +7124,6 @@ static struct ggml_tensor * llm_build_kqv(
|
|||||||
cb(cur, "kqv_merged_cont", il);
|
cb(cur, "kqv_merged_cont", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model.arch == LLM_ARCH_BITNET)
|
|
||||||
{
|
|
||||||
cur = llm_build_norm(ctx, cur, hparams,
|
|
||||||
attn_sub_norm, NULL,
|
|
||||||
LLM_NORM_RMS, cb, il);
|
|
||||||
cb(cur, "attn_sub_norm", il);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_build_forward_expand(graph, cur);
|
ggml_build_forward_expand(graph, cur);
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx, wo, cur);
|
cur = ggml_mul_mat(ctx, wo, cur);
|
||||||
@ -7178,7 +7169,7 @@ static struct ggml_tensor * llm_build_kv(
|
|||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
|
cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
|
||||||
q_cur, kq_mask, nullptr, n_tokens, n_kv, kq_scale, cb, il);
|
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
|
|
||||||
return cur;
|
return cur;
|
||||||
@ -11590,10 +11581,84 @@ struct llm_build_context {
|
|||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||||
cur = llm_build_kqv(ctx0, model, hparams, cparams, kv_self, gf,
|
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
const int64_t n_ctx = cparams.n_ctx;
|
||||||
Qcur, KQ_mask, model.layers[il].attn_sub_norm, n_tokens, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il
|
const int64_t n_head = hparams.n_head;
|
||||||
);
|
const int64_t n_head_kv = hparams.n_head_kv;
|
||||||
|
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||||
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||||
|
const int64_t n_embd_head_v = hparams.n_embd_head_v;
|
||||||
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
|
||||||
|
struct ggml_tensor * q_cur = Qcur;
|
||||||
|
struct ggml_tensor * kq_mask = KQ_mask;
|
||||||
|
float kq_scale = 1.0f/sqrtf(float(n_embd_head));
|
||||||
|
struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
|
||||||
|
struct ggml_cgraph * graph = gf;
|
||||||
|
struct ggml_tensor * wo = model.layers[il].wo;
|
||||||
|
struct ggml_tensor * cur_attn;
|
||||||
|
struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
||||||
|
cb(q, "q", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * k =
|
||||||
|
ggml_view_3d(ctx0, kv_self.k_l[il],
|
||||||
|
n_embd_head_k, n_kv, n_head_kv,
|
||||||
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||||
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||||
|
0);
|
||||||
|
cb(k, "k", il);
|
||||||
|
|
||||||
|
if (cparams.flash_attn) {
|
||||||
|
|
||||||
|
// split cached v into n_head heads (not transposed)
|
||||||
|
struct ggml_tensor * v =
|
||||||
|
ggml_view_3d(ctx0, kv_self.v_l[il],
|
||||||
|
n_embd_head_v, n_kv, n_head_kv,
|
||||||
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||||
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
|
||||||
|
0);
|
||||||
|
cb(v, "v", il);
|
||||||
|
|
||||||
|
cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||||
|
|
||||||
|
cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
|
||||||
|
} else {
|
||||||
|
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||||
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
|
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||||
|
cb(kq, "kq_soft_max_ext", il);
|
||||||
|
|
||||||
|
GGML_ASSERT(kv_self.size == n_ctx);
|
||||||
|
|
||||||
|
// split cached v into n_head heads
|
||||||
|
struct ggml_tensor * v =
|
||||||
|
ggml_view_3d(ctx0, kv_self.v_l[il],
|
||||||
|
n_kv, n_embd_head_v, n_head_kv,
|
||||||
|
ggml_element_size(kv_self.v_l[il])*n_ctx,
|
||||||
|
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
|
||||||
|
0);
|
||||||
|
cb(v, "v", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
||||||
|
cb(kqv, "kqv", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||||
|
cb(kqv_merged, "kqv_merged", il);
|
||||||
|
|
||||||
|
cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
|
||||||
|
cb(cur_attn, "kqv_merged_cont", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
|
||||||
|
attn_sub_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur_attn, "attn_sub_norm", il);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(graph, cur_attn);
|
||||||
|
|
||||||
|
cur = ggml_mul_mat(ctx0, wo, cur_attn);
|
||||||
|
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user