parallel : various improvements

This commit is contained in:
Georgi Gerganov 2023-09-19 12:29:37 +03:00
parent 467e307931
commit 36714e16d0
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -80,10 +80,13 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const int n_clients = 4; const int n_clients = 8;
// insert new requests as soon as the previous one is done // insert new requests as soon as the previous one is done
const bool hot_swap = true; const bool hot_plug = false;
// requests to simulate
const int32_t n_seq = 128;
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("parallel", "log")); log_set_target(log_filename_generator("parallel", "log"));
@ -95,7 +98,6 @@ int main(int argc, char ** argv) {
llama_backend_init(params.numa); llama_backend_init(params.numa);
llama_model * model = NULL; llama_model * model = NULL;
llama_context * ctx = NULL; llama_context * ctx = NULL;
// load the target model // load the target model
@ -130,11 +132,9 @@ int main(int argc, char ** argv) {
int32_t n_total_prompt = 0; int32_t n_total_prompt = 0;
int32_t n_total_gen = 0; int32_t n_total_gen = 0;
float t_avg = 0.0f; const auto t_main_start = ggml_time_us();
const int32_t n_seq = 128; while (true) {
while (g_seq_id < n_seq + n_clients) {
uint32_t n_tokens = 0; uint32_t n_tokens = 0;
batch_token.clear(); batch_token.clear();
@ -148,7 +148,7 @@ int main(int argc, char ** argv) {
} }
batch_token.push_back(client.sampled); batch_token.push_back(client.sampled);
batch_pos.push_back(client.n_decoded); batch_pos.push_back(client.n_decoded + client.n_prompt);
batch_seq_id.push_back(client.seq_id); batch_seq_id.push_back(client.seq_id);
batch_logits.push_back(true); batch_logits.push_back(true);
batch_clients.push_back(&client); batch_clients.push_back(&client);
@ -161,9 +161,9 @@ int main(int argc, char ** argv) {
llama_kv_cache_rm_tokens(ctx, -1, -1); llama_kv_cache_rm_tokens(ctx, -1, -1);
} }
if (hot_swap || batch_token.empty()) { if (hot_plug || batch_token.empty()) {
for (auto & client : clients) { for (auto & client : clients) {
if (client.seq_id == -1) { if (client.seq_id == -1 && g_seq_id < n_seq) {
client.seq_id = g_seq_id; client.seq_id = g_seq_id;
client.t_start_prompt = ggml_time_us(); client.t_start_prompt = ggml_time_us();
client.t_start_gen = 0; client.t_start_gen = 0;
@ -186,13 +186,20 @@ int main(int argc, char ** argv) {
batch_logits.back() = true; batch_logits.back() = true;
client.n_prompt = prompt_tokens.size(); client.n_prompt = prompt_tokens.size();
client.n_decoded = prompt_tokens.size(); client.n_decoded = 0;
client.i_batch = batch_token.size() - 1; client.i_batch = batch_token.size() - 1;
g_seq_id += 1; g_seq_id += 1;
if (hot_plug) {
break;
} }
} }
} }
}
if (batch_token.empty()) {
break;
}
// process in chunks of params.n_batch // process in chunks of params.n_batch
for (size_t i = 0; i < batch_token.size(); i += params.n_batch) { for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
@ -223,7 +230,9 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i); const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
if (client.t_start_gen == 0) { if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
// have their prompt already processed
client.t_start_gen = ggml_time_us(); client.t_start_gen = ggml_time_us();
} }
@ -238,9 +247,10 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n", //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || if (client.n_decoded > 2 &&
(id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
client.response.find("User:") != std::string::npos || client.response.find("User:") != std::string::npos ||
client.response.find('\n') != std::string::npos) { client.response.find('\n') != std::string::npos)) {
// basic reverse prompt // basic reverse prompt
const size_t pos = client.response.find("User:"); const size_t pos = client.response.find("User:");
if (pos != std::string::npos) { if (pos != std::string::npos) {
@ -252,18 +262,16 @@ int main(int argc, char ** argv) {
const auto t_main_end = ggml_time_us(); const auto t_main_end = ggml_time_us();
printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n", printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt, client.id, client.seq_id, client.n_prompt, client.n_decoded,
(t_main_end - client.t_start_prompt) / 1e6, (t_main_end - client.t_start_prompt) / 1e6,
(double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6, (double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
(double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6, (double) (client.n_decoded ) / (t_main_end - client.t_start_gen) * 1e6,
(double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6, (double) (client.n_decoded + client.n_prompt) / (t_main_end - client.t_start_prompt) * 1e6,
::trim(client.input).c_str(), ::trim(client.input).c_str(),
::trim(client.response).c_str()); ::trim(client.response).c_str());
n_total_prompt += client.n_prompt; n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded - client.n_prompt; n_total_gen += client.n_decoded;
t_avg += (t_main_end - client.t_start_prompt) / 1e6;
client.seq_id = -1; client.seq_id = -1;
} }
@ -273,10 +281,12 @@ int main(int argc, char ** argv) {
} }
} }
const auto t_main_end = ggml_time_us();
LOG_TEE("\n\n"); LOG_TEE("\n\n");
LOG_TEE("Total prompt tokens: %d\n", n_total_prompt); LOG_TEE("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6);
LOG_TEE("Total gen tokens: %d\n", n_total_gen); LOG_TEE("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6);
LOG_TEE("Avg time per seq: %.2f s\n", t_avg / n_seq); LOG_TEE("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
LOG_TEE("\n\n"); LOG_TEE("\n\n");