From addae65fd44d362995acd8c05b99c3351c214df8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 20 Sep 2023 10:46:18 +0300 Subject: [PATCH] llama : improve llama_batch API + simplify parallel example --- examples/parallel/parallel.cpp | 107 +++++++++++++-------------- examples/perplexity/perplexity.cpp | 2 +- examples/simple/simple.cpp | 8 +- examples/speculative/speculative.cpp | 2 +- llama.cpp | 30 +++++++- llama.h | 32 +++++--- 6 files changed, 111 insertions(+), 70 deletions(-) diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 4af4d2cd2..e252b0f53 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -127,11 +127,7 @@ int main(int argc, char ** argv) { llama_seq_id g_seq_id = 0; - std::vector batch_token; - std::vector batch_pos; - std::vector batch_seq_id; - std::vector batch_logits; - std::vector batch_clients; + llama_batch batch = llama_batch_init(params.n_batch, 0); int32_t n_total_prompt = 0; int32_t n_total_gen = 0; @@ -146,24 +142,15 @@ int main(int argc, char ** argv) { { LOG_TEE("%s: Evaluating the system prompt ...\n", __func__); - batch_pos.clear(); - batch_seq_id.clear(); + batch.n_tokens = n_tokens_system; - for (size_t i = 0; i < n_tokens_system; ++i) { - batch_pos.push_back(i); - batch_seq_id.push_back(0); + for (uint32_t i = 0; i < batch.n_tokens; ++i) { + batch.token[i] = tokens_system[i]; + batch.pos[i] = i; + batch.seq_id[i] = 0; + batch.logits[i] = false; } - llama_batch batch = { - n_tokens_system, - tokens_system.data(), - nullptr, - batch_pos.data(), - batch_seq_id.data(), - nullptr, - 0, 0, 0, // unused - }; - if (llama_decode(ctx, batch, params.n_threads) != 0) { LOG_TEE("%s: llama_decode() failed\n", __func__); return 1; @@ -180,63 +167,72 @@ int main(int argc, char ** argv) { LOG_TEE("Processing requests ...\n\n"); while (true) { - uint32_t n_tokens = 0; - - batch_token.clear(); - batch_pos.clear(); - batch_seq_id.clear(); - batch_logits.clear(); + batch.n_tokens = 0; + // decode any currently ongoing sequences for (auto & client : clients) { if (client.seq_id == -1) { continue; } - batch_token.push_back(client.sampled); - batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded); - batch_seq_id.push_back(client.id); - batch_logits.push_back(true); - batch_clients.push_back(&client); + batch.token [batch.n_tokens] = client.sampled; + batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded; + batch.seq_id[batch.n_tokens] = client.id; + batch.logits[batch.n_tokens] = true; + client.n_decoded += 1; - client.i_batch = batch_token.size() - 1; + client.i_batch = batch.n_tokens; + + batch.n_tokens += 1; } - if (batch_token.empty()) { + if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 0; i < n_clients; ++i) { llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1); } } - if (cont_batching || batch_token.empty()) { + // insert new sequences for decoding + if (cont_batching || batch.n_tokens == 0) { for (auto & client : clients) { if (client.seq_id == -1 && g_seq_id < n_seq) { client.seq_id = g_seq_id; + client.t_start_prompt = ggml_time_us(); client.t_start_gen = 0; - client.input = k_prompts[rand() % k_prompts.size()]; - client.prompt = client.input + "\nAssistant:"; + client.input = k_prompts[rand() % k_prompts.size()]; + client.prompt = client.input + "\nAssistant:"; client.response = ""; + std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); std::vector tokens_prompt; tokens_prompt = ::llama_tokenize(ctx, client.prompt, true); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - batch_token.push_back(tokens_prompt[i]); - batch_pos.push_back(i + n_tokens_system); - batch_seq_id.push_back(client.id); - batch_clients.push_back(&client); - batch_logits.push_back(false); + batch.token [batch.n_tokens] = tokens_prompt[i]; + batch.pos [batch.n_tokens] = i + n_tokens_system; + batch.seq_id[batch.n_tokens] = client.id; + batch.logits[batch.n_tokens] = false; + batch.n_tokens += 1; + } + + // extract the logits only for the last token + if (batch.n_tokens > 0) { + batch.logits[batch.n_tokens - 1] = true; } - batch_logits.back() = true; client.n_prompt = tokens_prompt.size(); client.n_decoded = 0; - client.i_batch = batch_token.size() - 1; + client.i_batch = batch.n_tokens - 1; + + LOG_TEE("\033[1mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); g_seq_id += 1; + + // insert new requests one-by-one //if (cont_batching) { // break; //} @@ -244,34 +240,35 @@ int main(int argc, char ** argv) { } } - if (batch_token.empty()) { + if (batch.n_tokens == 0) { break; } // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) { - n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i)); + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - llama_batch batch = { + llama_batch batch_view = { n_tokens, - batch_token.data() + i, + batch.token + i, nullptr, - batch_pos.data() + i, - batch_seq_id.data() + i, - batch_logits.data() + i, + batch.pos + i, + batch.seq_id + i, + batch.logits + i, 0, 0, 0, // unused }; - const int ret = llama_decode(ctx, batch, params.n_threads); + const int ret = llama_decode(ctx, batch_view, params.n_threads); if (ret != 0) { if (n_batch == 1 || ret < 0) { - LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); + // if you get here, it means the KV cache is full - try increasing it via the context size + LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); return 1; } - LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2); + LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); n_cache_miss += 1; @@ -357,6 +354,8 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); + llama_batch_free(batch); + llama_free(ctx); llama_free_model(model); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index be87011d1..190634167 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -419,7 +419,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } static std::vector hellaswag_evaluate_tokens( - llama_context * ctx, const std::vector& tokens, int n_past, int n_batch, int n_vocab, int n_thread + llama_context * ctx, std::vector & tokens, int n_past, int n_batch, int n_vocab, int n_thread ) { std::vector result; result.reserve(tokens.size() * n_vocab); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 593949c87..8a9a1bf54 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -10,10 +10,12 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]); + printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]); return 1 ; } + int n_parallel = 1; + if (argc >= 2) { params.model = argv[1]; } @@ -22,6 +24,10 @@ int main(int argc, char ** argv) { params.prompt = argv[2]; } + if (argc >= 4) { + n_parallel = std::atoi(argv[3]); + } + if (params.prompt.empty()) { params.prompt = "Hello my name is"; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index df93c9cd4..2445d78dc 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -134,7 +134,7 @@ int main(int argc, char ** argv) { while (true) { // sample from the target model - const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); + llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); diff --git a/llama.cpp b/llama.cpp index f38a033a5..f47d9b598 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7356,7 +7356,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi int llama_eval( struct llama_context * ctx, - const llama_token * tokens, + llama_token * tokens, uint32_t n_tokens, int n_past, int n_threads) { @@ -7376,7 +7376,7 @@ int llama_eval( int llama_eval_embd( struct llama_context * ctx, - const float * embd, + float * embd, uint32_t n_tokens, int n_past, int n_threads) { @@ -7397,7 +7397,7 @@ int llama_eval_embd( } struct llama_batch llama_batch_get_one( - const llama_token * tokens, + llama_token * tokens, uint32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { @@ -7414,6 +7414,30 @@ struct llama_batch llama_batch_get_one( }; } +struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd) { + llama_batch batch = { n_tokens, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; + + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); + } + + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens); + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +void llama_batch_free(struct llama_batch batch) { + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.seq_id) free(batch.seq_id); + if (batch.logits) free(batch.logits); +} + int llama_decode( struct llama_context * ctx, struct llama_batch batch, diff --git a/llama.h b/llama.h index 2f344eb14..3a46e1ea0 100644 --- a/llama.h +++ b/llama.h @@ -70,11 +70,11 @@ extern "C" { typedef struct llama_batch { uint32_t n_tokens; - const llama_token * token; - const float * embd; - const llama_pos * pos; - const llama_seq_id * seq_id; - const int8_t * logits; // if 0, do not extract logits for that token + llama_token * token; + float * embd; + llama_pos * pos; + llama_seq_id * seq_id; + int8_t * logits; // if 0, do not extract logits for that token // NOTE: helpers for smooth API transition - can be deprecated in the future // for future-proof code, use the above fields instead and ignore everything below @@ -84,7 +84,7 @@ extern "C" { llama_pos all_pos_0; // used if pos == NULL llama_pos all_pos_1; // used if pos == NULL llama_seq_id all_seq_id; // used if seq_id == NULL - } llama_seq; + } llama_batch; enum llama_log_level { LLAMA_LOG_LEVEL_ERROR = 2, @@ -366,34 +366,46 @@ extern "C" { // tokens + n_tokens is the provided batch of new tokens to process // n_past is the number of tokens to use from previous eval calls // Returns 0 on success + // DEPRECATED: use llama_decode() instead LLAMA_API DEPRECATED(int llama_eval( struct llama_context * ctx, - const llama_token * tokens, + llama_token * tokens, uint32_t n_tokens, int n_past, int n_threads), "please use llama_decode() instead"); // Same as llama_eval, but use float matrix input directly. + // DEPRECATED: use llama_decode() instead LLAMA_API DEPRECATED(int llama_eval_embd( struct llama_context * ctx, - const float * embd, + float * embd, uint32_t n_tokens, int n_past, int n_threads), "please use llama_decode() instead"); // Return batch for single sequence of tokens starting at pos_0 - // If pos_0 == 0, the clear_kv flag will be auto set to true // // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // LLAMA_API struct llama_batch llama_batch_get_one( - const llama_token * tokens, + llama_token * tokens, uint32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id); + // Allocates a batch of tokens on the heap + // The batch needs to be freed with llama_batch_free() + // If embd > 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) + // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token + // The rest of the llama_batch members are allocated with size n_tokens + // All members are left uninitialized + LLAMA_API struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd); + + // Frees a batch of tokens allocated with llama_batch_init() + LLAMA_API void llama_batch_free(struct llama_batch batch); + // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)