llm : cleanup + comments

This commit is contained in:
Georgi Gerganov 2023-11-01 20:08:02 +02:00
parent 78186f4009
commit a8796f9609
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -3435,7 +3435,7 @@ struct llm_build_context {
const int64_t n_embd; const int64_t n_embd;
const int64_t n_layer; const int64_t n_layer;
const int64_t n_ctx; const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_head; const int64_t n_head;
const int64_t n_head_kv; const int64_t n_head_kv;
const int64_t n_embd_head; const int64_t n_embd_head;
@ -3447,8 +3447,8 @@ struct llm_build_context {
const float norm_rms_eps; const float norm_rms_eps;
const int32_t n_tokens; const int32_t n_tokens;
const int32_t n_kv; const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx)
const int32_t kv_head; const int32_t kv_head; // index of where we store new KV data in the cache
const bool do_rope_shift; const bool do_rope_shift;
@ -3457,7 +3457,6 @@ struct llm_build_context {
llama_buffer & buf_compute; llama_buffer & buf_compute;
struct ggml_context * ctx0 = nullptr; struct ggml_context * ctx0 = nullptr;
struct ggml_cgraph * gf0 = nullptr;
// TODO: consider making the entire interface noexcept // TODO: consider making the entire interface noexcept
llm_build_context( llm_build_context(
@ -3500,8 +3499,6 @@ struct llm_build_context {
}; };
ctx0 = ggml_init(params); ctx0 = ggml_init(params);
gf0 = ggml_new_graph(ctx0);
} }
void free() { void free() {
@ -3511,8 +3508,9 @@ struct llm_build_context {
} }
} }
public:
struct ggml_cgraph * build_llama() { struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
GGML_ASSERT(n_embd_head == hparams.n_rot); GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur; struct ggml_tensor * cur;
@ -3535,7 +3533,7 @@ public:
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, kv_self, gf0, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -3565,7 +3563,7 @@ public:
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf0, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, cur, hparams, kv_self, cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
@ -3609,13 +3607,14 @@ public:
cur = ggml_mul_mat(ctx0, model.output, cur); cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf0, cur); ggml_build_forward_expand(gf, cur);
return gf0; return gf;
} }
public:
struct ggml_cgraph * build_baichuan() { struct ggml_cgraph * build_baichuan() {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -3636,7 +3635,7 @@ public:
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, kv_self, gf0, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -3673,7 +3672,7 @@ public:
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf0, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
// apply ALiBi for 13B model // apply ALiBi for 13B model
const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f; const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
@ -3720,13 +3719,14 @@ public:
cur = ggml_mul_mat(ctx0, model.output, cur); cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf0, cur); ggml_build_forward_expand(gf, cur);
return gf0; return gf;
} }
public:
struct ggml_cgraph * build_falcon() { struct ggml_cgraph * build_falcon() {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -3747,7 +3747,7 @@ public:
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, kv_self, gf0, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -3793,7 +3793,7 @@ public:
Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale); Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf0, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, attn_norm, hparams, kv_self, cur = llm_build_kqv(ctx0, attn_norm, hparams, kv_self,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
@ -3835,13 +3835,14 @@ public:
cur = ggml_mul_mat(ctx0, model.output, cur); cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf0, cur); ggml_build_forward_expand(gf, cur);
return gf0; return gf;
} }
public:
struct ggml_cgraph * build_starcoder() { struct ggml_cgraph * build_starcoder() {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * pos; struct ggml_tensor * pos;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -3892,7 +3893,7 @@ public:
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
llm_build_kv_store(ctx0, hparams, kv_self, gf0, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, cur, hparams, kv_self, cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
@ -3933,13 +3934,14 @@ public:
cur = ggml_mul_mat(ctx0, model.output, cur); cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf0, cur); ggml_build_forward_expand(gf, cur);
return gf0; return gf;
} }
public:
struct ggml_cgraph * build_persimmon() { struct ggml_cgraph * build_persimmon() {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
const int64_t n_rot = n_embd_head / 2; const int64_t n_rot = n_embd_head / 2;
struct ggml_tensor * cur; struct ggml_tensor * cur;
@ -3959,7 +3961,7 @@ public:
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, kv_self, gf0, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -4095,7 +4097,7 @@ public:
); );
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf0, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
// TODO: not tested, could be broken // TODO: not tested, could be broken
cur = llm_build_kqv(ctx0, Q, hparams, kv_self, cur = llm_build_kqv(ctx0, Q, hparams, kv_self,
@ -4140,13 +4142,14 @@ public:
cur = ggml_mul_mat(ctx0, model.output, cur); cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf0, cur); ggml_build_forward_expand(gf, cur);
return gf0; return gf;
} }
public:
struct ggml_cgraph * build_refact() { struct ggml_cgraph * build_refact() {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -4186,7 +4189,7 @@ public:
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf0, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self, cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
@ -4230,13 +4233,14 @@ public:
cur = ggml_mul_mat(ctx0, model.output, cur); cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf0, cur); ggml_build_forward_expand(gf, cur);
return gf0; return gf;
} }
public:
struct ggml_cgraph * build_bloom() { struct ggml_cgraph * build_bloom() {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -4282,7 +4286,7 @@ public:
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
llm_build_kv_store(ctx0, hparams, kv_self, gf0, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self, cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
@ -4323,13 +4327,14 @@ public:
cur = ggml_mul_mat(ctx0, model.output, cur); cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf0, cur); ggml_build_forward_expand(gf, cur);
return gf0; return gf;
} }
public:
struct ggml_cgraph * build_mpt() { struct ggml_cgraph * build_mpt() {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -4375,7 +4380,7 @@ public:
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
llm_build_kv_store(ctx0, hparams, kv_self, gf0, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self, cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
@ -4421,9 +4426,9 @@ public:
cur = ggml_mul_mat(ctx0, model.output, cur); cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf0, cur); ggml_build_forward_expand(gf, cur);
return gf0; return gf;
} }
}; };