diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 6e68c5afc..3c3fe6ddb 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -80,10 +80,13 @@ int main(int argc, char ** argv) { return 1; } - const int n_clients = 4; + const int n_clients = 8; // insert new requests as soon as the previous one is done - const bool hot_swap = true; + const bool hot_plug = false; + + // requests to simulate + const int32_t n_seq = 128; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("parallel", "log")); @@ -95,7 +98,6 @@ int main(int argc, char ** argv) { llama_backend_init(params.numa); llama_model * model = NULL; - llama_context * ctx = NULL; // load the target model @@ -130,11 +132,9 @@ int main(int argc, char ** argv) { int32_t n_total_prompt = 0; int32_t n_total_gen = 0; - float t_avg = 0.0f; + const auto t_main_start = ggml_time_us(); - const int32_t n_seq = 128; - - while (g_seq_id < n_seq + n_clients) { + while (true) { uint32_t n_tokens = 0; batch_token.clear(); @@ -148,7 +148,7 @@ int main(int argc, char ** argv) { } batch_token.push_back(client.sampled); - batch_pos.push_back(client.n_decoded); + batch_pos.push_back(client.n_decoded + client.n_prompt); batch_seq_id.push_back(client.seq_id); batch_logits.push_back(true); batch_clients.push_back(&client); @@ -161,12 +161,12 @@ int main(int argc, char ** argv) { llama_kv_cache_rm_tokens(ctx, -1, -1); } - if (hot_swap || batch_token.empty()) { + if (hot_plug || batch_token.empty()) { for (auto & client : clients) { - if (client.seq_id == -1) { + if (client.seq_id == -1 && g_seq_id < n_seq) { client.seq_id = g_seq_id; client.t_start_prompt = ggml_time_us(); - client.t_start_gen = 0; + client.t_start_gen = 0; client.input = k_prompts[rand() % k_prompts.size()]; client.prompt = k_system + client.input + "\nAssistant:"; @@ -186,14 +186,21 @@ int main(int argc, char ** argv) { batch_logits.back() = true; client.n_prompt = prompt_tokens.size(); - client.n_decoded = prompt_tokens.size(); + client.n_decoded = 0; client.i_batch = batch_token.size() - 1; g_seq_id += 1; + if (hot_plug) { + break; + } } } } + if (batch_token.empty()) { + break; + } + // process in chunks of params.n_batch for (size_t i = 0; i < batch_token.size(); i += params.n_batch) { n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i)); @@ -223,7 +230,9 @@ int main(int argc, char ** argv) { const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i); - if (client.t_start_gen == 0) { + if (client.n_decoded == 1) { + // start measuring generation time after the first token to make sure all concurrent clients + // have their prompt already processed client.t_start_gen = ggml_time_us(); } @@ -238,9 +247,10 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n", // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); - if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || - client.response.find("User:") != std::string::npos || - client.response.find('\n') != std::string::npos) { + if (client.n_decoded > 2 && + (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || + client.response.find("User:") != std::string::npos || + client.response.find('\n') != std::string::npos)) { // basic reverse prompt const size_t pos = client.response.find("User:"); if (pos != std::string::npos) { @@ -252,18 +262,16 @@ int main(int argc, char ** argv) { const auto t_main_end = ggml_time_us(); printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n", - client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt, + client.id, client.seq_id, client.n_prompt, client.n_decoded, (t_main_end - client.t_start_prompt) / 1e6, (double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6, - (double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6, - (double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6, + (double) (client.n_decoded ) / (t_main_end - client.t_start_gen) * 1e6, + (double) (client.n_decoded + client.n_prompt) / (t_main_end - client.t_start_prompt) * 1e6, ::trim(client.input).c_str(), ::trim(client.response).c_str()); n_total_prompt += client.n_prompt; - n_total_gen += client.n_decoded - client.n_prompt; - - t_avg += (t_main_end - client.t_start_prompt) / 1e6; + n_total_gen += client.n_decoded; client.seq_id = -1; } @@ -273,10 +281,12 @@ int main(int argc, char ** argv) { } } + const auto t_main_end = ggml_time_us(); + LOG_TEE("\n\n"); - LOG_TEE("Total prompt tokens: %d\n", n_total_prompt); - LOG_TEE("Total gen tokens: %d\n", n_total_gen); - LOG_TEE("Avg time per seq: %.2f s\n", t_avg / n_seq); + LOG_TEE("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6); + LOG_TEE("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6); + LOG_TEE("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6); LOG_TEE("\n\n");