From fa0e67782086a19d47f327d51a4be14a45e4b891 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 19 Sep 2023 00:24:13 +0300
Subject: [PATCH] llama : extend batch API to select which logits to output

---
 examples/embd-input/embd-input-lib.cpp |  2 +-
 examples/parallel/parallel.cpp         | 34 ++++++++++++++++++++++++--
 llama.cpp                              | 14 +++++++++--
 llama.h                                |  2 +-
 4 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 344a8b2c3..339612cce 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
+        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
         if (llama_decode(ctx, batch, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index a8b6f629d..6e68c5afc 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -82,6 +82,9 @@ int main(int argc, char ** argv) {
 
     const int n_clients = 4;
 
+    // insert new requests as soon as the previous one is done
+    const bool hot_swap = true;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("parallel", "log"));
     LOG_TEE("Log start\n");
@@ -121,14 +124,23 @@ int main(int argc, char ** argv) {
     std::vector<llama_token>  batch_token;
     std::vector<llama_pos>    batch_pos;
     std::vector<llama_seq_id> batch_seq_id;
+    std::vector<int8_t>       batch_logits;
     std::vector<Client *>     batch_clients;
 
-    while (true) {
+    int32_t n_total_prompt = 0;
+    int32_t n_total_gen    = 0;
+
+    float t_avg = 0.0f;
+
+    const int32_t n_seq = 128;
+
+    while (g_seq_id < n_seq + n_clients) {
         uint32_t n_tokens = 0;
 
         batch_token.clear();
         batch_pos.clear();
         batch_seq_id.clear();
+        batch_logits.clear();
 
         for (auto & client : clients) {
             if (client.seq_id == -1) {
@@ -138,6 +150,7 @@ int main(int argc, char ** argv) {
             batch_token.push_back(client.sampled);
             batch_pos.push_back(client.n_decoded);
             batch_seq_id.push_back(client.seq_id);
+            batch_logits.push_back(true);
             batch_clients.push_back(&client);
             client.n_decoded += 1;
             client.i_batch = batch_token.size() - 1;
@@ -146,7 +159,9 @@ int main(int argc, char ** argv) {
         if (batch_token.empty()) {
             // all sequences have ended - clear the entire KV cache
             llama_kv_cache_rm_tokens(ctx, -1, -1);
+        }
 
+        if (hot_swap || batch_token.empty()) {
             for (auto & client : clients) {
                 if (client.seq_id == -1) {
                     client.seq_id = g_seq_id;
@@ -166,7 +181,10 @@ int main(int argc, char ** argv) {
                     batch_pos.push_back(i);
                     batch_seq_id.push_back(client.seq_id);
                     batch_clients.push_back(&client);
+                    batch_logits.push_back(false);
                 }
+                batch_logits.back() = true;
+
                 client.n_prompt  = prompt_tokens.size();
                 client.n_decoded = prompt_tokens.size();
                 client.i_batch   = batch_token.size() - 1;
@@ -186,6 +204,7 @@ int main(int argc, char ** argv) {
                 nullptr,
                 batch_pos.data()    + i,
                 batch_seq_id.data() + i,
+                batch_logits.data() + i,
                 0, 0, 0, // unused
             };
 
@@ -232,14 +251,20 @@ int main(int argc, char ** argv) {
 
                 const auto t_main_end = ggml_time_us();
 
-                printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
+                printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
                        client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
+                        (t_main_end - client.t_start_prompt) / 1e6,
                         (double) (client.n_prompt                   ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
                         (double) (client.n_decoded - client.n_prompt) / (t_main_end         - client.t_start_gen)    * 1e6,
                         (double) (client.n_decoded                  ) / (t_main_end         - client.t_start_prompt) * 1e6,
                         ::trim(client.input).c_str(),
                         ::trim(client.response).c_str());
 
+                n_total_prompt += client.n_prompt;
+                n_total_gen    += client.n_decoded - client.n_prompt;
+
+                t_avg += (t_main_end - client.t_start_prompt) / 1e6;
+
                 client.seq_id = -1;
             }
 
@@ -248,6 +273,11 @@ int main(int argc, char ** argv) {
         }
     }
 
+    LOG_TEE("\n\n");
+    LOG_TEE("Total prompt tokens: %d\n", n_total_prompt);
+    LOG_TEE("Total gen tokens:    %d\n", n_total_gen);
+    LOG_TEE("Avg time per seq:    %.2f s\n", t_avg / n_seq);
+
     LOG_TEE("\n\n");
 
     llama_print_timings(ctx);
diff --git a/llama.cpp b/llama.cpp
index 3a4a2b6ac..3e54fed7c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4140,7 +4140,16 @@ static bool llama_eval_internal(
 
     if (lctx.logits_all) {
         logits_out.resize(n_vocab * n_tokens);
-        memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
+        if (batch.logits) {
+            for (uint32_t i = 0; i < n_tokens; i++) {
+                if (batch.logits[i] == 0) {
+                    continue;
+                }
+                memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
+            }
+        } else {
+            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
+        }
     } else {
         // return result for just the last token
         logits_out.resize(n_vocab);
@@ -7318,7 +7327,7 @@ int llama_eval_embd(
         int n_threads) {
     llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
 
-    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, };
+    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
 
     if (!llama_eval_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@@ -7346,6 +7355,7 @@ struct llama_batch llama_batch_get_one(
         /*embd       =*/ nullptr,
         /*pos        =*/ nullptr,
         /*seq_id     =*/ nullptr,
+        /*logits     =*/ nullptr,
         /*all_pos_0  =*/ pos_0,
         /*all_pos_1  =*/ 1,
         /*all_seq_id =*/ seq_id,
diff --git a/llama.h b/llama.h
index 4a5f2e3bf..e4f02c978 100644
--- a/llama.h
+++ b/llama.h
@@ -70,11 +70,11 @@ extern "C" {
     typedef struct llama_batch {
         uint32_t n_tokens;
 
-        // TODO: not sure about these consts - might just get in the way all the time with no benefit
        const llama_token  * token;
        const float        * embd;
        const llama_pos    * pos;
        const llama_seq_id * seq_id;
+        const int8_t       * logits; // if 0, do not extract logits for that token
 
         // NOTE: helpers for smooth API transition - can be deprecated in the future
         // for future-proof code, use the above fields instead and ignore everything below
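
Usage note (not part of the patch): the new `logits` field lets a caller request logit extraction for only a subset of batch positions, so the backend skips the `n_vocab`-float memcpy for every other token - exactly what the parallel example does by masking in only the last token of each prompt. Below is a minimal sketch, assuming a context created with `logits_all` enabled (the path on which this patch honors the per-token mask) and hypothetical variables `ctx`, `prompt_tokens`, and `n_threads`; `pos`/`seq_id` are left as nullptr to fall back to the `all_pos_0`/`all_pos_1`/`all_seq_id` helpers, mirroring `llama_batch_get_one`:

    // build a per-token mask: extract logits only for the last prompt token
    std::vector<int8_t> logits_mask(prompt_tokens.size(), 0);
    logits_mask.back() = 1;

    llama_batch batch = {
        (uint32_t) prompt_tokens.size(),
        prompt_tokens.data(),  // token
        nullptr,               // embd   - unused when passing tokens
        nullptr,               // pos    - use all_pos_0/all_pos_1 instead
        nullptr,               // seq_id - use all_seq_id instead
        logits_mask.data(),    // new field: 0 = do not extract logits for that token
        0, 1, 0,               // all_pos_0, all_pos_1, all_seq_id
    };

    if (llama_decode(ctx, batch, n_threads)) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
    }

    // the logits for the masked-in position can then be read via llama_get_logits(ctx)

The mask is declared as `int8_t` rather than `bool`, presumably because `llama_batch` lives in a C header inside an `extern "C"` block, where a fixed-width integer keeps the ABI unambiguous.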