From 4b5f3cd6bf51d7974c7480fa23a44563b0a785a4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 19 Sep 2023 17:00:42 +0300 Subject: [PATCH] parallel : process system prompt once + configurable parameters + llama API --- common/common.cpp | 20 ++++- common/common.h | 3 + examples/llama-bench/llama-bench.cpp | 4 +- examples/main/main.cpp | 4 +- examples/parallel/parallel.cpp | 101 ++++++++++++++++------ examples/perplexity/perplexity.cpp | 6 +- examples/speculative/speculative.cpp | 6 +- llama.cpp | 121 ++++++++++++++++----------- llama.h | 15 +++- 9 files changed, 187 insertions(+), 93 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 52387e2a6..8bd006960 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -317,6 +317,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.n_chunks = std::stoi(argv[i]); + } else if (arg == "-np" || arg == "--parallel") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_parallel = std::stoi(argv[i]); + } else if (arg == "-ns" || arg == "--sequences") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_sequences = std::stoi(argv[i]); } else if (arg == "-m" || arg == "--model") { if (++i >= argc) { invalid_param = true; @@ -360,6 +372,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.multiline_input = true; } else if (arg == "--simple-io") { params.simple_io = true; + } else if (arg == "--hot-plug") { + params.hot_plug = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -659,6 +673,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", 
params.n_chunks); + printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); + printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); + printf(" --hot-plug enable hot-plugging of new sequences for decoding (default: disabled)\n"); if (llama_mlock_supported()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } @@ -781,7 +798,7 @@ std::tuple llama_init_from_gpt_par std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads); - llama_kv_cache_rm_tokens(lctx, -1, -1); + llama_kv_cache_tokens_rm(lctx, -1, -1); llama_reset_timings(lctx); } @@ -1253,6 +1270,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); + fprintf(stream, "hot_plug: %s # default: false\n", params.hot_plug ? 
"true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", params.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); diff --git a/common/common.h b/common/common.h index 945403274..9269a5d36 100644 --- a/common/common.h +++ b/common/common.h @@ -43,6 +43,8 @@ struct gpt_params { int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_draft = 16; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors @@ -108,6 +110,7 @@ struct gpt_params { bool interactive_first = false; // wait for user input immediately bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles + bool hot_plug = false; // hot-plug new sequences for decoding bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 7a3d3b97f..4d23db5ee 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -977,7 +977,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_kv_cache_rm_tokens(ctx, -1, -1); + llama_kv_cache_tokens_rm(ctx, -1, -1); // warmup run if (t.n_prompt > 0) { @@ -988,7 +988,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_kv_cache_rm_tokens(ctx, -1, -1); + llama_kv_cache_tokens_rm(ctx, 
-1, -1); uint64_t t_start = get_time_ns(); if (t.n_prompt > 0) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9c5f2746a..1ed543cbc 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -505,8 +505,8 @@ int main(int argc, char ** argv) { LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_rm_seq (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_shift_seq(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index c6e8d9f5c..20918fd31 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -35,7 +35,7 @@ User: Hello, what is the temperature outside? Assistant: It is 72 degrees Fahrenheit. User: What is the definition of a prime number? Assistant: A prime number is a number that is divisible only by itself and 1. 
-User: )"; +User:)"; static std::vector k_prompts = { "What is the meaning of life?", @@ -70,7 +70,7 @@ struct client { std::string prompt; std::string response; - std::vector last_tokens; + std::vector tokens_prev; }; int main(int argc, char ** argv) { @@ -80,13 +80,14 @@ int main(int argc, char ** argv) { return 1; } - const int n_clients = 8; - - // insert new requests as soon as the previous one is done - const bool hot_plug = true; + // number of simultaneous "clients" to simulate + const int32_t n_clients = params.n_parallel; // requests to simulate - const int32_t n_seq = 128; + const int32_t n_seq = params.n_sequences; + + // insert new requests as soon as the previous one is done + const bool hot_plug = params.hot_plug; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("parallel", "log")); @@ -114,13 +115,17 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.last_tokens.resize(n_ctx); - std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0); + client.tokens_prev.resize(n_ctx); + std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); } std::vector candidates; candidates.reserve(n_vocab); + std::vector tokens_system; + tokens_system = ::llama_tokenize(ctx, k_system, true); + const uint32_t n_tokens_system = tokens_system.size(); + llama_seq_id g_seq_id = 0; std::vector batch_token; @@ -134,6 +139,44 @@ int main(int argc, char ** argv) { const auto t_main_start = ggml_time_us(); + LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__); + LOG_TEE("%s: n_parallel = %d, n_sequences = %d, hot_plug = %d, system tokens = %d\n", __func__, n_clients, n_seq, hot_plug, n_tokens_system); + LOG_TEE("\n"); + + { + LOG_TEE("%s: Evaluating the system prompt ...\n", __func__); + + batch_pos.clear(); + batch_seq_id.clear(); + + for (size_t i = 0; i < n_tokens_system; ++i) { + batch_pos.push_back(i); + batch_seq_id.push_back(0); + } + + 
llama_batch batch = { + n_tokens_system, + tokens_system.data(), + nullptr, + batch_pos.data(), + batch_seq_id.data(), + nullptr, + 0, 0, 0, // unused + }; + + if (llama_decode(ctx, batch, params.n_threads) != 0) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + // assign the system KV cache to all parallel sequences + for (int32_t i = 1; i < n_clients; ++i) { + llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system); + } + + LOG_TEE("\n"); + } + while (true) { uint32_t n_tokens = 0; @@ -148,7 +191,7 @@ int main(int argc, char ** argv) { } batch_token.push_back(client.sampled); - batch_pos.push_back(client.n_decoded + client.n_prompt); + batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded); batch_seq_id.push_back(client.seq_id); batch_logits.push_back(true); batch_clients.push_back(&client); @@ -158,34 +201,36 @@ int main(int argc, char ** argv) { if (batch_token.empty()) { // all sequences have ended - clear the entire KV cache - llama_kv_cache_rm_tokens(ctx, -1, -1); + for (int i = 0; i < n_clients; ++i) { + llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1); + } } if (hot_plug || batch_token.empty()) { for (auto & client : clients) { if (client.seq_id == -1 && g_seq_id < n_seq) { - client.seq_id = g_seq_id; + client.seq_id = client.id; client.t_start_prompt = ggml_time_us(); client.t_start_gen = 0; client.input = k_prompts[rand() % k_prompts.size()]; - client.prompt = k_system + client.input + "\nAssistant:"; + client.prompt = client.input + "\nAssistant:"; client.response = ""; - std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0); + std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); - std::vector prompt_tokens; - prompt_tokens = ::llama_tokenize(ctx, client.prompt, true); + std::vector tokens_prompt; + tokens_prompt = ::llama_tokenize(ctx, client.prompt, true); - for (size_t i = 0; i < prompt_tokens.size(); ++i) { - batch_token.push_back(prompt_tokens[i]); - batch_pos.push_back(i); + 
for (size_t i = 0; i < tokens_prompt.size(); ++i) { + batch_token.push_back(tokens_prompt[i]); + batch_pos.push_back(i + n_tokens_system); batch_seq_id.push_back(client.seq_id); batch_clients.push_back(&client); batch_logits.push_back(false); } batch_logits.back() = true; - client.n_prompt = prompt_tokens.size(); + client.n_prompt = tokens_prompt.size(); client.n_decoded = 0; client.i_batch = batch_token.size() - 1; @@ -217,9 +262,10 @@ int main(int argc, char ** argv) { 0, 0, 0, // unused }; - if (llama_decode(ctx, batch, params.n_threads)) { - if (n_batch == 1) { - LOG_TEE("%s : failed to decode batch\n", __func__); + const int ret = llama_decode(ctx, batch, params.n_threads); + if (ret != 0) { + if (n_batch == 1 || ret < 0) { + LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); return 1; } @@ -242,7 +288,7 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i); + const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -251,8 +297,8 @@ int main(int argc, char ** argv) { } // remember which tokens were sampled - used for repetition penalties during sampling - client.last_tokens.erase(client.last_tokens.begin()); - client.last_tokens.push_back(id); + client.tokens_prev.erase(client.tokens_prev.begin()); + client.tokens_prev.push_back(id); const std::string token_str = llama_token_to_piece(ctx, id); client.response += token_str; @@ -271,7 +317,8 @@ int main(int argc, char ** argv) { client.response = client.response.substr(0, pos); } - llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx); + // delete 
only the generated part of the sequence, i.e. keep the system prompt in the cache + llama_kv_cache_seq_rm(ctx, client.seq_id, n_tokens_system, n_ctx); const auto t_main_end = ggml_time_us(); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 8386a3d16..be87011d1 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -207,7 +207,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_keep_seq(ctx, -1); + llama_kv_cache_tokens_rm(ctx, -1, -1); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -335,7 +335,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_keep_seq(ctx, -1); + llama_kv_cache_tokens_rm(ctx, -1, -1); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -568,7 +568,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } // clear the KV cache - llama_kv_cache_keep_seq(ctx, -1); + llama_kv_cache_tokens_rm(ctx, -1, -1); auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads); if (logits.empty()) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index ea628211b..df93c9cd4 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { LOG("out of drafted tokens\n"); } - llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx); + llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx); llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads); ++n_past_dft; @@ -257,7 +257,7 @@ int main(int argc, char ** argv) { } // evaluate the 
drafted token on the draft model - llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx); + llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx); llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads); ++n_past_cur; @@ -267,7 +267,7 @@ int main(int argc, char ** argv) { } // evaluate the target model on the drafted tokens - llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx); + llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx); llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads); ++n_past_tgt; diff --git a/llama.cpp b/llama.cpp index 089b87f56..12b8c49d0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1328,7 +1328,7 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } -static void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { +static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { if (c0 < 0) c0 = 0; if (c1 < 0) c1 = cache.size; @@ -1338,7 +1338,7 @@ static void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, } } -static void llama_kv_cache_rm_seq( +static void llama_kv_cache_seq_rm( struct llama_kv_cache & cache, llama_seq_id seq_id, llama_pos p0, @@ -1353,7 +1353,20 @@ static void llama_kv_cache_rm_seq( } } -static void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) { +static void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.insert(seq_id_dst); + } + } +} + +static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { for (uint32_t i = 0; i < cache.size; ++i) { if (!cache.cells[i].has_seq_id(seq_id)) { 
cache.cells[i].pos = -1; @@ -1362,7 +1375,7 @@ static void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id } } -static void llama_kv_cache_shift_seq( +static void llama_kv_cache_seq_shift( struct llama_kv_cache & cache, llama_seq_id seq_id, llama_pos p0, @@ -4019,7 +4032,11 @@ static struct ggml_cgraph * llama_build_graph( // - batch: batch to evaluate // - n_threads: number of threads to use // -static bool llama_decode_internal( +// return 0 on success +// return positive int on warning +// return negative int on error +// +static int llama_decode_internal( llama_context & lctx, llama_batch batch, int n_threads) { @@ -4027,7 +4044,7 @@ static bool llama_decode_internal( if (n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); - return false; + return -1; } GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -4079,7 +4096,7 @@ static bool llama_decode_internal( kv_self.head = 0; if (!llama_kv_cache_find_slot(kv_self, batch)) { - return false; + return 1; } // a heuristic, to avoid attending the full cache if it is not yet utilized @@ -4203,7 +4220,14 @@ static bool llama_decode_internal( lctx.n_p_eval += n_tokens; } - return true; + // get a more accurate load time, upon first eval + // TODO: fix this + if (!lctx.has_evaluated_once) { + lctx.t_load_us = ggml_time_us() - lctx.t_start_us; + lctx.has_evaluated_once = true; + } + + return 0; } // @@ -6920,20 +6944,24 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) { return ctx->kv_self.head; } -void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1) { - llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1); +void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) { + llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1); } -void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - llama_kv_cache_rm_seq(ctx->kv_self, seq_id, p0, p1); +void 
llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); } -void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) { - llama_kv_cache_keep_seq(ctx->kv_self, seq_id); +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); } -void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - llama_kv_cache_shift_seq(ctx->kv_self, seq_id, p0, p1, delta); +void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { + llama_kv_cache_seq_keep(ctx->kv_self, seq_id); +} + +void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); } // Returns the *maximum* size of the state @@ -7330,21 +7358,18 @@ int llama_eval( uint32_t n_tokens, int n_past, int n_threads) { - llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1); + llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - if (!llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) { - //LLAMA_LOG_ERROR("%s: failed to decode\n", __func__); - return 1; + const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads); + if (ret != 0) { + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; } - // get a more accurate load time, upon first eval - // TODO: fix this - if (!ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; - } - - return 0; + return ret; } int llama_eval_embd( @@ -7353,23 +7378,20 @@ int llama_eval_embd( uint32_t n_tokens, int n_past, int n_threads) 
{ - llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1); + llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; - if (!llama_decode_internal(*ctx, batch, n_threads)) { - //LLAMA_LOG_ERROR("%s: failed to decode\n", __func__); - return 1; + const int ret = llama_decode_internal(*ctx, batch, n_threads); + if (ret != 0) { + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; } - // get a more accurate load time, upon first eval - // TODO: fix this - if (!ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; - } - - return 0; + return ret; } struct llama_batch llama_batch_get_one( @@ -7394,19 +7416,16 @@ int llama_decode( struct llama_context * ctx, struct llama_batch batch, int n_threads) { - if (!llama_decode_internal(*ctx, batch, n_threads)) { - //LLAMA_LOG_ERROR("%s: failed to decode\n", __func__); - return 1; + const int ret = llama_decode_internal(*ctx, batch, n_threads); + if (ret != 0) { + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; } - // get a more accurate load time, upon first eval - // TODO: fix this - if (!ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; - } - - return 0; + return ret; } float * llama_get_logits(struct llama_context * ctx) { diff --git a/llama.h b/llama.h index e4f02c978..2f344eb14 100644 --- a/llama.h +++ b/llama.h @@ -322,17 +322,20 @@ extern "C" { "avoid using this, it will be removed in the future, instead - count the tokens in user code"); // Remove all tokens data of cells in [c0, c1) - LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1); + LLAMA_API void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1); // Removes all tokens that belong to the 
specified sequence and have positions in [p0, p1) - LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + + // Copy all tokens that belong to the specified sequence and have positions in [p0, p1) to another sequence + LLAMA_API void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly - LLAMA_API void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); + LLAMA_API void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); // // State / sessions @@ -391,6 +394,10 @@ extern "C" { llama_pos pos_0, llama_seq_id seq_id); + // A positive return value does not mean a fatal error, but rather a warning. + // 0 - success + // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context) + // < 0 - error LLAMA_API int llama_decode( struct llama_context * ctx, struct llama_batch batch,