llama : fix defrag bugs + add parameter (#5735)

* llama : fix defrag bugs + enable by default

ggml-ci

* llama : add defrag_thold parameter

ggml-ci

* llama : cont

* llama : disable log message

ggml-ci

* llama : fix graph size check during defrag
This commit is contained in:
Georgi Gerganov 2024-02-27 14:35:51 +02:00 committed by GitHub
parent cbbd1efa06
commit 9d533a77d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 82 additions and 30 deletions

View File

@ -335,6 +335,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.yarn_beta_slow = std::stof(argv[i]); params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--defrag-thold" || arg == "-dt") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.defrag_thold = std::stof(argv[i]);
} else if (arg == "--samplers") { } else if (arg == "--samplers") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -1004,6 +1010,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" -dt N, --defrag-thold N\n");
printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n"); printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
@ -1285,6 +1293,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_fast = params.yarn_beta_fast; cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.defrag_thold = params.defrag_thold;
cparams.offload_kqv = !params.no_kv_offload; cparams.offload_kqv = !params.no_kv_offload;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_k = kv_cache_type_from_str(params.cache_type_k);

View File

@ -75,6 +75,7 @@ struct gpt_params {
float yarn_beta_fast = 32.0f; // YaRN low correction dim float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

View File

@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_kv_cache_defrag (ctx); //llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx); llama_kv_cache_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
@ -213,7 +213,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_kv_cache_defrag (ctx); //llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx); llama_kv_cache_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;

View File

@ -1641,6 +1641,7 @@ struct llama_cparams {
float yarn_attn_factor; float yarn_attn_factor;
float yarn_beta_fast; float yarn_beta_fast;
float yarn_beta_slow; float yarn_beta_slow;
float defrag_thold;
bool mul_mat_q; bool mul_mat_q;
bool offload_kqv; bool offload_kqv;
@ -5117,16 +5118,16 @@ struct llm_build_context {
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) { struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
for (int i = 0; i < n_kv; ++i) { for (uint32_t i = 0; i < ids.size(); ++i) {
const int id = ids[i]; const uint32_t id = ids[i];
if (i == id || id == n_kv) { if (i == id || id == ids.size()) {
continue; continue;
} }
int nm = 1; uint32_t nm = 1;
while (i + nm < n_kv && (int) ids[i + nm] == id + nm) { while (i + nm < ids.size() && ids[i + nm] == id + nm) {
nm++; nm++;
} }
@ -5158,6 +5159,8 @@ struct llm_build_context {
i += nm - 1; i += nm - 1;
} }
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
return gf; return gf;
} }
@ -7938,6 +7941,8 @@ static int llama_decode_internal(
batch.seq_id = seq_id_arr.data(); batch.seq_id = seq_id_arr.data();
} }
llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head -> // if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it // better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) { if (kv_self.head > kv_self.used + 2*n_tokens) {
@ -7956,8 +7961,6 @@ static int llama_decode_internal(
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
llama_kv_cache_update(&lctx);
ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
@ -8007,6 +8010,18 @@ static int llama_decode_internal(
} }
} }
// decide if we need to defrag the kv cache
if (cparams.defrag_thold >= 0.0f) {
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
llama_kv_cache_defrag(kv_self);
}
}
#ifdef GGML_PERF #ifdef GGML_PERF
// print timing information per ggml operation (for debugging purposes) // print timing information per ggml operation (for debugging purposes)
// requires GGML_PERF to be defined // requires GGML_PERF to be defined
@ -8098,12 +8113,16 @@ static int llama_decode_internal(
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
auto & kv_self = lctx.kv_self; auto & kv_self = lctx.kv_self;
const auto & hparams = lctx.model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self); const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
const uint32_t n_used = kv_self.used; const uint32_t n_used = kv_self.used;
assert(n_used <= n_kv); assert(n_used <= n_kv);
const int64_t t_start = ggml_time_us(); //const int64_t t_start = ggml_time_us();
// number of cells moved // number of cells moved
uint32_t n_moves = 0; uint32_t n_moves = 0;
@ -8127,15 +8146,26 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// found a hole - fill it with data from the end of the cache // found a hole - fill it with data from the end of the cache
// determine the size of the hole
uint32_t nh = 1; uint32_t nh = 1;
// determine the size of the hole
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) { while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
nh++; nh++;
} }
// starting from the end, find nh non-empty cells // each move requires 6*n_layer tensors (see build_defrag)
// - source view, destination view, copy operation
// - x2 for keys and values
//
if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
// the graph is too big, we cannot move more cells
break;
}
uint32_t nf = 0; uint32_t nf = 0;
uint32_t is = n_kv - 1; uint32_t is = n_kv - 1;
// starting from the end, find nh non-empty cells
for (; is > i0; --is) { for (; is > i0; --is) {
const auto & cell1 = kv_self.cells[is]; const auto & cell1 = kv_self.cells[is];
@ -8156,11 +8186,17 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
nf = 0; nf = 0;
uint32_t i1 = is;
// are we moving a continuous block of memory?
bool cont = false;
// go back and move the nf cells to the hole // go back and move the nf cells to the hole
for (uint32_t i1 = is; i1 < n_kv; ++i1) { for (; i1 < n_kv; ++i1) {
const auto & cell1 = kv_self.cells[i1]; auto & cell1 = kv_self.cells[i1];
if (cell1.is_empty() || ids[i1] != n_kv) { if (cell1.is_empty() || ids[i1] != n_kv) {
cont = false;
continue; continue;
} }
@ -8170,11 +8206,23 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// move the cell meta data // move the cell meta data
kv_self.cells[i0 + nf] = cell1; kv_self.cells[i0 + nf] = cell1;
// clear the old cell and move the head there
cell1 = llama_kv_cell();
kv_self.head = n_used;
if (!cont) {
n_moves++; n_moves++;
nf++; cont = true;
} }
LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh); nf++;
if (nf == nh) {
break;
}
}
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
i0 += nh - 1; i0 += nh - 1;
} }
@ -8183,15 +8231,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
return; return;
} }
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves); //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
kv_self.head = n_used; //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
kv_self.used = n_used;
// zero the rest of the cells
for (uint32_t i = n_used; i < n_kv; ++i) {
kv_self.cells[i] = llama_kv_cell();
}
#if 0 #if 0
// CPU defrag // CPU defrag
@ -8203,9 +8245,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// likely not worth the effort, as we have ggml_graph based defrag // likely not worth the effort, as we have ggml_graph based defrag
// //
const auto & hparams = lctx.model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
@ -8274,9 +8313,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
llama_graph_compute(lctx, gf, lctx.cparams.n_threads); llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
#endif #endif
const int64_t t_end = ggml_time_us(); //const int64_t t_end = ggml_time_us();
LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0); //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
} }
static void llama_kv_cache_update_internal(struct llama_context & lctx) { static void llama_kv_cache_update_internal(struct llama_context & lctx) {
@ -11670,6 +11709,7 @@ struct llama_context_params llama_context_default_params() {
/*.yarn_beta_fast =*/ 32.0f, /*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f, /*.yarn_beta_slow =*/ 1.0f,
/*.yarn_orig_ctx =*/ 0, /*.yarn_orig_ctx =*/ 0,
/*.defrag_thold =*/ -1.0f,
/*.cb_eval =*/ nullptr, /*.cb_eval =*/ nullptr,
/*.cb_eval_user_data =*/ nullptr, /*.cb_eval_user_data =*/ nullptr,
/*.type_k =*/ GGML_TYPE_F16, /*.type_k =*/ GGML_TYPE_F16,
@ -11834,6 +11874,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_attn_factor = params.yarn_attn_factor; cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast; cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.defrag_thold = params.defrag_thold;
cparams.mul_mat_q = params.mul_mat_q; cparams.mul_mat_q = params.mul_mat_q;
cparams.offload_kqv = params.offload_kqv; cparams.offload_kqv = params.offload_kqv;
cparams.do_pooling = params.do_pooling; cparams.do_pooling = params.do_pooling;
@ -12035,7 +12076,7 @@ struct llama_context * llama_new_context_with_model(
} }
// buffer used to store the computation graph and the tensor meta data // buffer used to store the computation graph and the tensor meta data
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead()); ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES); ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);

View File

@ -245,6 +245,7 @@ extern "C" {
float yarn_beta_fast; // YaRN low correction dim float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size uint32_t yarn_orig_ctx; // YaRN original context size
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
ggml_backend_sched_eval_callback cb_eval; ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data; void * cb_eval_user_data;