mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-15 23:00:46 +01:00)

commit e640cbe055
parent d1259b7b35

llama : add n_expert and n_expert_used to hparams + change quants

convert.py

@@ -158,7 +158,9 @@ class Params:
     n_ff:        int
     n_head:      int
     n_head_kv:   int
-    f_norm_eps:  float
+    n_experts:      int | None = None
+    n_experts_used: int | None = None
+    f_norm_eps:     float | None = None
 
     rope_scaling_type: gguf.RopeScalingType | None = None
     f_rope_freq_base: float | None = None
@@ -255,6 +257,9 @@ class Params:
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))
 
+        n_experts      = None
+        n_experts_used = None
+
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
         if config.get("rope_theta") == 1000000:
             # CodeLlama
@@ -262,21 +267,21 @@ class Params:
         elif config["norm_eps"] == 1e-05:
             # LLaMA v2
             n_ctx = 4096
+        elif config["moe"]:
+            # Mixtral
+            n_ctx = 32768
         else:
             # LLaMA v1
             n_ctx = 2048
 
-        # print model keys
-        for k in model.keys():
-            print(k)
-
-        # check if MoE
-        if "layers.0.feed_forward.experts.0.w1.weight" in model:
-            n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
-            n_ctx = 32768
-        else:
-            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
+        if "layers.0.feed_forward.w1.weight" in model:
+            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
+
+        if config.get("moe"):
+            n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
+            n_experts      = config["moe"]["num_experts"]
+            n_experts_used = config["moe"]["num_experts_per_tok"]
 
         return Params(
             n_vocab          = model["tok_embeddings.weight"].shape[0],
             n_embd           = config["dim"],
@@ -285,6 +290,8 @@ class Params:
             n_ff             = n_ff,
             n_head           = (n_head := config["n_heads"]),
             n_head_kv        = config.get("n_kv_heads", n_head),
+            n_experts        = n_experts,
+            n_experts_used   = n_experts_used,
             f_norm_eps       = config["norm_eps"],
             f_rope_freq_base = config.get("rope_theta"),
         )
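
Note: the config.get("moe") branch above keys off a "moe" block in the original params.json. A minimal sketch of the shape it expects; the concrete values are illustrative (typical Mixtral-8x7B numbers), not taken from this commit:

# illustrative excerpt of a Mixtral-style params.json, limited to the keys read above
config = {
    "dim": 4096,
    "n_heads": 32,
    "moe": {
        "num_experts": 8,          # -> Params.n_experts
        "num_experts_per_tok": 2,  # -> Params.n_experts_used
    },
}

n_experts      = None
n_experts_used = None
if config.get("moe"):
    n_experts      = config["moe"]["num_experts"]
    n_experts_used = config["moe"]["num_experts_per_tok"]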

@@ -843,7 +850,17 @@ class OutputFile:
         self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
         self.gguf.add_head_count          (params.n_head)
         self.gguf.add_head_count_kv       (params.n_head_kv)
-        self.gguf.add_layer_norm_rms_eps  (params.f_norm_eps)
+
+        if params.n_experts:
+            self.gguf.add_expert_count(params.n_experts)
+
+        if params.n_experts_used:
+            self.gguf.add_expert_used_count(params.n_experts_used)
+
+        if params.f_norm_eps:
+            self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
+        else:
+            raise ValueError('f_norm_eps is None')
 
         if params.f_rope_freq_base is not None:
             self.gguf.add_rope_freq_base(params.f_rope_freq_base)

ggml.c

@@ -4075,7 +4075,7 @@ struct ggml_tensor * ggml_mul_mat(
 
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
-        struct ggml_tensor  * as[],
+        struct ggml_tensor  * const as[],
         int                   n_as,
         struct ggml_tensor  * ids,
         int                   id,

ggml.h

@@ -1051,7 +1051,7 @@ extern "C" {
     // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * as[],
+            struct ggml_tensor  * const as[],
             int                   n_as,
             struct ggml_tensor  * ids,
             int                   id,
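
Note: the comment in ggml.h above is the whole contract of ggml_mul_mat_id. A loose NumPy sketch of the way the MoE graph further down uses it (per-token expert selection); this ignores ggml's mul_mat layout conventions, is not the actual kernel, and all names below are illustrative:

import numpy as np

def mul_mat_id(as_, ids, i, b):
    # rough model of ggml_mul_mat_id(ctx, as, n_as, ids, i, b):
    # row t of b is multiplied by the matrix of the expert selected in ids[t, i]
    out = np.empty((b.shape[0], as_[0].shape[0]))
    for t in range(b.shape[0]):
        out[t] = as_[ids[t, i]] @ b[t]
    return out

experts = [np.ones((4, 3)), 2.0 * np.ones((4, 3))]   # n_as = 2 expert matrices
ids     = np.array([[0], [1], [0]])                  # selected expert per token
x       = np.ones((3, 3))                            # 3 tokens, 3 features
y       = mul_mat_id(experts, ids, 0, x)             # shape (3, 4)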

gguf-py/gguf/constants.py

@@ -38,6 +38,8 @@ class Keys:
         FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
         USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
         TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
+        EXPERT_COUNT          = "{arch}.expert_count"
+        EXPERT_USED_COUNT     = "{arch}.expert_used_count"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"

gguf-py/gguf/gguf_writer.py

@@ -339,6 +339,12 @@ class GGUFWriter:
     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
 
+    def add_expert_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
+
+    def add_expert_used_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
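Note: a minimal usage sketch of the two new writer methods. The output path, arch string, and surrounding call sequence are illustrative gguf-py usage, not part of this commit:

from gguf import GGUFWriter

w = GGUFWriter("mixtral-f16.gguf", "llama")   # illustrative path / arch
w.add_expert_count(8)                         # writes "llama.expert_count" (uint32)
w.add_expert_used_count(2)                    # writes "llama.expert_used_count" (uint32)

w.write_header_to_file()
w.write_kv_data_to_file()
w.close()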

llama.cpp

@@ -92,6 +92,7 @@
 #endif
 
 #define LLAMA_MAX_NODES 8192
+#define LLAMA_MAX_EXPERTS 8
 
 //
 // logging
@@ -231,6 +232,8 @@ enum llm_kv {
     LLM_KV_FEED_FORWARD_LENGTH,
     LLM_KV_USE_PARALLEL_RESIDUAL,
     LLM_KV_TENSOR_DATA_LAYOUT,
+    LLM_KV_EXPERT_COUNT,
+    LLM_KV_EXPERT_USED_COUNT,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -281,6 +284,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_FEED_FORWARD_LENGTH,     "%s.feed_forward_length"     },
     { LLM_KV_USE_PARALLEL_RESIDUAL,   "%s.use_parallel_residual"   },
     { LLM_KV_TENSOR_DATA_LAYOUT,      "%s.tensor_data_layout"      },
+    { LLM_KV_EXPERT_COUNT,            "%s.expert_count"            },
+    { LLM_KV_EXPERT_USED_COUNT,       "%s.expert_used_count"       },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,    "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -1176,6 +1181,8 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_ff;
+    uint32_t n_expert = 0;
+    uint32_t n_expert_used = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -1199,6 +1206,9 @@ struct llama_hparams {
         if (this->n_layer != other.n_layer) return true;
         if (this->n_rot   != other.n_rot)   return true;
         if (this->n_ff    != other.n_ff)    return true;
+        if (this->n_expert      != other.n_expert)      return true;
+        if (this->n_expert_used != other.n_expert_used) return true;
+
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 
@@ -1282,9 +1292,9 @@ struct llama_layer {
 
     // ff MoE
     struct ggml_tensor * ffn_gate_inp;
-    struct ggml_tensor * ffn_gate_exp[8];
-    struct ggml_tensor * ffn_down_exp[8];
-    struct ggml_tensor * ffn_up_exp[8];
+    struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS];
+    struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS];
+    struct ggml_tensor * ffn_up_exp  [LLAMA_MAX_EXPERTS];
 
     // ff bias
     struct ggml_tensor * ffn_down_b; // b2
@@ -2458,6 +2468,16 @@ static void llm_load_hparams(
     ml.get_key (LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff);
     ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
     ml.get_key (LLM_KV_BLOCK_COUNT,          hparams.n_layer);
+    ml.get_key (LLM_KV_EXPERT_COUNT,         hparams.n_expert,      false);
+    ml.get_key (LLM_KV_EXPERT_USED_COUNT,    hparams.n_expert_used, false);
+
+    GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
+    GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
+    if (hparams.n_expert > 0) {
+        GGML_ASSERT(hparams.n_expert_used > 0);
+    } else {
+        GGML_ASSERT(hparams.n_expert_used == 0);
+    }
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv = hparams.n_head;
@@ -2889,6 +2909,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
+    LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
+    LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
     LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type.c_str());
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
@@ -3046,10 +3068,16 @@ static void llm_load_tensors(
                         layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false);
 
                         if (layer.ffn_gate_inp == nullptr) {
+                            GGML_ASSERT(hparams.n_expert      == 0);
+                            GGML_ASSERT(hparams.n_expert_used == 0);
+
                             layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, backend_split);
                             layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
                             layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
                         } else {
+                            GGML_ASSERT(hparams.n_expert      > 0);
+                            GGML_ASSERT(hparams.n_expert_used > 0);
+
                             // MoE branch
                             for (int x = 0; x < 8; ++x) {
                                 layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}, backend_split);
@@ -3073,7 +3101,7 @@ static void llm_load_tensors(
                                 ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
                         } else {
                             vram_weights += ggml_nbytes(layer.ffn_gate_inp);
-                            for (int x = 0; x < 8; ++x) {
+                            for (uint32_t x = 0; x < hparams.n_expert; ++x) {
                                 vram_weights +=
                                     ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]);
                             }
@@ -4058,6 +4086,8 @@ struct llm_build_context {
     const int64_t n_head_kv;
     const int64_t n_embd_head;
     const int64_t n_embd_gqa;
+    const int64_t n_expert;
+    const int64_t n_expert_used;
 
     const float freq_base;
     const float freq_scale;
@@ -4099,6 +4129,8 @@ struct llm_build_context {
         n_head_kv        (hparams.n_head_kv),
         n_embd_head      (hparams.n_embd_head()),
         n_embd_gqa       (hparams.n_embd_gqa()),
+        n_expert         (hparams.n_expert),
+        n_expert_used    (hparams.n_expert_used),
         freq_base        (cparams.rope_freq_base),
         freq_scale       (cparams.rope_freq_scale),
         ext_factor       (cparams.yarn_ext_factor),
@@ -4242,10 +4274,6 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            // TODO: param
-            const int n_experts = 8;
-            const int n_experts_per_tok = 2;
-
             ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
             cb(logits, "ffn_moe_logits", il);
 
@@ -4253,14 +4281,14 @@ struct llm_build_context {
             cb(probs, "ffn_moe_probs", il);
 
             // select experts
-            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
+            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
             cb(selected_experts->src[0], "ffn_moe_argsort", il);
 
             ggml_tensor * weights = ggml_get_rows(ctx0,
-                    ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
+                    ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
             cb(weights, "ffn_moe_weights", il);
 
-            weights = ggml_reshape_2d(ctx0, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
+            weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
 
             ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
             cb(weights_sum, "ffn_moe_weights_sum", il);
@@ -4271,18 +4299,13 @@ struct llm_build_context {
             // compute expert outputs
             ggml_tensor * moe_out = nullptr;
 
-            for (int i = 0; i < n_experts_per_tok; ++i) {
+            for (int i = 0; i < n_expert_used; ++i) {
                 ggml_tensor * cur_expert;
 
-                // TODO: fix
-                ggml_tensor ** ffn_up_exp   = (ggml_tensor **) model.layers[il].ffn_up_exp;
-                ggml_tensor ** ffn_gate_exp = (ggml_tensor **) model.layers[il].ffn_gate_exp;
-                ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp;
-
-                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, ffn_up_exp, n_experts, selected_experts, i, cur);
+                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
                 cb(cur_up, "ffn_moe_up", il);
 
-                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, ffn_gate_exp, n_experts, selected_experts, i, cur);
+                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
                 cb(cur_gate, "ffn_moe_gate", il);
 
                 cur_gate = ggml_silu(ctx0, cur_gate);
@@ -4291,7 +4314,7 @@ struct llm_build_context {
                 cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
                 cb(cur_expert, "ffn_moe_gate_par", il);
 
-                cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
                 cb(cur_expert, "ffn_moe_down", il);
 
                 cur_expert = ggml_mul(ctx0, cur_expert,
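
Note: taken together, the graph built above is standard top-k expert routing: softmax over the gate logits, select n_expert_used experts per token, renormalize their weights (the weights_sum above), then accumulate the weighted expert FFN outputs into moe_out. A plain NumPy sketch of the same math, for orientation only; it is not the ggml implementation, and the helper names are illustrative:

import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def moe_ffn(x, w_gate_inp, experts, n_expert_used):
    # x: [n_tokens, n_embd]; w_gate_inp: [n_embd, n_expert]
    # experts: list of (w_up, w_gate, w_down) per expert
    logits = x @ w_gate_inp                                  # ffn_moe_logits
    probs  = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)                    # ffn_moe_probs
    sel = np.argsort(-probs, axis=-1)[:, :n_expert_used]     # top-k expert ids
    w   = np.take_along_axis(probs, sel, axis=-1)            # ffn_moe_weights
    w  /= w.sum(-1, keepdims=True)                           # normalize by weights_sum

    out = np.zeros_like(x)
    for t in range(x.shape[0]):                  # per token
        for i in range(n_expert_used):           # per selected expert
            w_up, w_gate, w_down = experts[sel[t, i]]
            h = silu(x[t] @ w_gate) * (x[t] @ w_up)   # ffn_moe_gate_par
            out[t] += w[t, i] * (h @ w_down)          # weighted ffn_moe_down
    return out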

@@ -8192,11 +8215,9 @@ static void llama_convert_tensor_internal(
     workers.clear();
 }
 
-static ggml_type get_k_quant_type(
-    quantize_state_internal & qs,
-    ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype
-) {
+static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);
 
     // TODO: avoid hardcoded tensor names - use the TN_* constants
     const llm_arch arch = qs.model.arch;
     const auto     tn   = LLM_TN(arch);
@@ -8230,7 +8251,18 @@ static ggml_type get_k_quant_type(
             // nearly negligible increase in model size by quantizing this tensor with more bits:
             if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
         }
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
         ++qs.i_attention_wv;
+    } else if (name.find("attn_k.weight") != std::string::npos) {
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
     } else if (name.find("ffn_down.weight") != std::string::npos) {
         if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
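
Note: a back-of-the-envelope check of the "~128MB" comment above, under assumed Mixtral-8x7B attention shapes (32 layers, n_embd = 4096, 8 KV heads giving 1024-wide K/V projections) and approximate k-quant bit widths. All numbers are assumptions for illustration and are not taken from this commit:

n_layer, n_embd, n_embd_gqa = 32, 4096, 1024   # assumed Mixtral-8x7B shapes
params = n_layer * 2 * n_embd * n_embd_gqa     # attn_k.weight + attn_v.weight elements
bpw_q8_0, bpw_q4_k = 8.5, 4.5                  # approximate bits per weight
delta_mib = params * (bpw_q8_0 - bpw_q4_k) / 8 / 2**20
print(f"{delta_mib:.0f} MiB")                  # -> 128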