From 1fb5d4fdee54cf48d915005b5125c1650c35a909 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Wed, 17 Jul 2024 14:48:09 -0400
Subject: [PATCH] llama : apply suggestions

Co-authored-by: Georgi Gerganov
---
 src/llama.cpp | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 4fc3359b9..738f6d3af 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3018,7 +3018,7 @@ struct llama_sbatch {
             return;
         }
         std::sort(ids.begin(), ids.end(),
-            [batch](size_t a, size_t b) {
+            [&batch](size_t a, size_t b) {
                 int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
                 int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
                 // sort by seq_id, then by pos
@@ -3050,7 +3050,6 @@ struct llama_sbatch {
         if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
             for (size_t i = 0; i < n_tokens; ++i) {
                 const size_t bi = ids[i];
-                const size_t s_len = seq.size();
                 const int32_t n_seqs = batch.n_seq_id[bi];
                 llama_seq_id * seq_ids = batch.seq_id[bi];
                 if (last_seq != nullptr) {
@@ -3067,7 +3066,7 @@ struct llama_sbatch {
                 }
                 llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
                 seq.push_back(new_seq);
-                last_seq = &seq[s_len];
+                last_seq = &seq.back();
             }
         } else {
             llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
@@ -15089,8 +15088,8 @@ static int llama_decode_internal(
 
     while (lctx.sbatch.n_tokens > 0) {
         // For now, only use equal splits for recurrent model architectures
-        llama_ubatch u_batch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
-        const uint32_t n_tokens = u_batch.n_tokens;
+        llama_ubatch ubatch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
+        const uint32_t n_tokens = ubatch.n_tokens;
 
         // count the outputs in this u_batch
         {
@@ -15099,9 +15098,9 @@ static int llama_decode_internal(
             if (n_outputs == n_tokens_all) {
                 n_outputs_new = n_tokens;
             } else {
-                GGML_ASSERT(u_batch.output);
+                GGML_ASSERT(ubatch.output);
                 for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (u_batch.output[i] != 0);
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
                 }
             }
 
@@ -15122,7 +15121,7 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
                 return 1;
             }
 
@@ -15141,7 +15140,7 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
+        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
         // the output is always the last tensor in the graph
         struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
@@ -15166,7 +15165,7 @@ static int llama_decode_internal(
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
-        llama_set_inputs(lctx, u_batch);
+        llama_set_inputs(lctx, ubatch);
 
         llama_graph_compute(lctx, gf, n_threads);
 
@@ -15229,7 +15228,7 @@ static int llama_decode_internal(
                 embd_seq_out.clear();
 
                 for (uint32_t i = 0; i < n_tokens; i++) {
-                    const llama_seq_id seq_id = u_batch.seq_id[i][0];
+                    const llama_seq_id seq_id = ubatch.seq_id[i][0];
                     if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                         continue;
                     }
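
For context, a minimal standalone sketch (not part of the patch; Batch, seq_id, and pos are simplified stand-ins for the llama_batch fields) of the sort-by-proxy pattern the first hunk touches: token indices are sorted by fields of a separate batch structure, and capturing the batch by reference ([&batch]) avoids copying it into the lambda, which is what the [batch] -> [&batch] change is about.

// Simplified sketch; Batch stands in for llama.cpp's llama_batch.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct Batch {
    std::vector<int32_t> seq_id; // sequence id of each token
    std::vector<int32_t> pos;    // position of each token
};

int main() {
    Batch batch = { {1, 0, 1, 0}, {0, 0, 1, 1} };

    // Sort token indices instead of moving the tokens themselves.
    std::vector<size_t> ids = {0, 1, 2, 3};

    // [&batch] captures by reference, so the lambda does not copy the
    // batch; [batch] would copy it once, needlessly for a large struct,
    // even though both captures produce the same ordering here.
    std::sort(ids.begin(), ids.end(),
        [&batch](size_t a, size_t b) {
            // sort by seq_id, then by pos
            if (batch.seq_id[a] != batch.seq_id[b]) {
                return batch.seq_id[a] < batch.seq_id[b];
            }
            return batch.pos[a] < batch.pos[b];
        });

    for (size_t id : ids) {
        printf("token %zu: seq_id=%d pos=%d\n",
            id, (int) batch.seq_id[id], (int) batch.pos[id]);
    }
    return 0;
}

Running this prints the indices grouped by seq_id (1, 3, 0, 2), matching the "sort by seq_id, then by pos" comment in the hunk.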
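
Likewise, a small sketch (plain std::vector; Seq is a hypothetical stand-in for llama_sbatch_seq) of the equivalence that lets the second and third hunks drop the s_len temporary: when s_len holds the vector's size just before push_back, &seq[s_len] and &seq.back() refer to the same element afterward, so the shorter form suffices.

// After push_back, the element at the old size() is exactly back().
#include <cassert>
#include <cstddef>
#include <vector>

struct Seq { int n_seqs; size_t offset; };

int main() {
    std::vector<Seq> seq;
    seq.push_back({1, 0});

    const size_t s_len = seq.size(); // size before the next push_back
    seq.push_back({2, 1});

    // Both expressions name the element that was just pushed.
    assert(&seq[s_len] == &seq.back());

    Seq * last_seq = &seq.back(); // the simpler form the patch switches to
    (void) last_seq;
    return 0;
}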