From 2b8830af7153a75fba5a899a8b331389025d0d03 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Sep 2023 17:48:25 +0300 Subject: [PATCH] examples : do not eval prompt 2 times (close #3348) --- examples/batched/batched.cpp | 29 ++++++++++++++++------------- examples/simple/simple.cpp | 16 ++++++++-------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 08b082b33..4dd1d553d 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -1,6 +1,7 @@ #include "common.h" #include "llama.h" +#include #include #include #include @@ -42,7 +43,9 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); ctx_params.seed = 1234; - ctx_params.n_ctx = 2048; + ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301) + ctx_params.n_batch = std::max(n_len, n_parallel); + // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); @@ -66,11 +69,11 @@ int main(int argc, char ** argv) { const int n_ctx = llama_n_ctx(ctx); const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel; - LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req); + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req); // make sure the KV cache is big enough to hold all the prompt and generated tokens if (n_kv_req > n_ctx) { - LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__); + LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req); LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__); return 1; } @@ -88,7 +91,7 @@ int main(int argc, char ** argv) { // create a llama_batch with size 512 // we use this object to submit token data for decoding - llama_batch batch = llama_batch_init(512, 0); + llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0); // evaluate the initial prompt batch.n_tokens = tokens_list.size(); @@ -133,12 +136,6 @@ int main(int argc, char ** argv) { const auto t_main_start = ggml_time_us(); while (n_cur <= n_len) { - // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch, params.n_threads)) { - fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); - return 1; - } - // prepare the next batch batch.n_tokens = 0; @@ -149,8 +146,8 @@ int main(int argc, char ** argv) { continue; } - auto n_vocab = llama_n_vocab(ctx); - auto logits = llama_get_logits_ith(ctx, i_batch[i]); + auto n_vocab = llama_n_vocab(ctx); + auto * logits = llama_get_logits_ith(ctx, i_batch[i]); std::vector candidates; candidates.reserve(n_vocab); @@ -178,7 +175,7 @@ int main(int argc, char ** argv) { i_batch[i] = -1; LOG_TEE("\n"); if (n_parallel > 1) { - LOG_TEE("%s: stream %d finished", __func__, i); + LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur); } continue; @@ -211,6 +208,12 @@ int main(int argc, char ** argv) { } n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch, params.n_threads)) { + fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } } LOG_TEE("\n"); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 2acdc7273..1616a4a75 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -110,16 +110,10 @@ int main(int argc, char ** argv) { const auto t_main_start = ggml_time_us(); while (n_cur <= n_len) { - // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch, params.n_threads)) { - fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); - return 1; - } - // sample the next token { - auto n_vocab = llama_n_vocab(ctx); - auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + auto n_vocab = llama_n_vocab(ctx); + auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); std::vector candidates; candidates.reserve(n_vocab); @@ -158,6 +152,12 @@ int main(int argc, char ** argv) { } n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch, params.n_threads)) { + fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } } LOG_TEE("\n");