mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 19:33:58 +01:00
parallel : various improvements
This commit is contained in:
parent
467e307931
commit
36714e16d0
@ -80,10 +80,13 @@ int main(int argc, char ** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int n_clients = 4;
|
const int n_clients = 8;
|
||||||
|
|
||||||
// insert new requests as soon as the previous one is done
|
// insert new requests as soon as the previous one is done
|
||||||
const bool hot_swap = true;
|
const bool hot_plug = false;
|
||||||
|
|
||||||
|
// requests to simulate
|
||||||
|
const int32_t n_seq = 128;
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("parallel", "log"));
|
log_set_target(log_filename_generator("parallel", "log"));
|
||||||
@ -95,7 +98,6 @@ int main(int argc, char ** argv) {
|
|||||||
llama_backend_init(params.numa);
|
llama_backend_init(params.numa);
|
||||||
|
|
||||||
llama_model * model = NULL;
|
llama_model * model = NULL;
|
||||||
|
|
||||||
llama_context * ctx = NULL;
|
llama_context * ctx = NULL;
|
||||||
|
|
||||||
// load the target model
|
// load the target model
|
||||||
@ -130,11 +132,9 @@ int main(int argc, char ** argv) {
|
|||||||
int32_t n_total_prompt = 0;
|
int32_t n_total_prompt = 0;
|
||||||
int32_t n_total_gen = 0;
|
int32_t n_total_gen = 0;
|
||||||
|
|
||||||
float t_avg = 0.0f;
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
const int32_t n_seq = 128;
|
while (true) {
|
||||||
|
|
||||||
while (g_seq_id < n_seq + n_clients) {
|
|
||||||
uint32_t n_tokens = 0;
|
uint32_t n_tokens = 0;
|
||||||
|
|
||||||
batch_token.clear();
|
batch_token.clear();
|
||||||
@ -148,7 +148,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
batch_token.push_back(client.sampled);
|
batch_token.push_back(client.sampled);
|
||||||
batch_pos.push_back(client.n_decoded);
|
batch_pos.push_back(client.n_decoded + client.n_prompt);
|
||||||
batch_seq_id.push_back(client.seq_id);
|
batch_seq_id.push_back(client.seq_id);
|
||||||
batch_logits.push_back(true);
|
batch_logits.push_back(true);
|
||||||
batch_clients.push_back(&client);
|
batch_clients.push_back(&client);
|
||||||
@ -161,12 +161,12 @@ int main(int argc, char ** argv) {
|
|||||||
llama_kv_cache_rm_tokens(ctx, -1, -1);
|
llama_kv_cache_rm_tokens(ctx, -1, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hot_swap || batch_token.empty()) {
|
if (hot_plug || batch_token.empty()) {
|
||||||
for (auto & client : clients) {
|
for (auto & client : clients) {
|
||||||
if (client.seq_id == -1) {
|
if (client.seq_id == -1 && g_seq_id < n_seq) {
|
||||||
client.seq_id = g_seq_id;
|
client.seq_id = g_seq_id;
|
||||||
client.t_start_prompt = ggml_time_us();
|
client.t_start_prompt = ggml_time_us();
|
||||||
client.t_start_gen = 0;
|
client.t_start_gen = 0;
|
||||||
|
|
||||||
client.input = k_prompts[rand() % k_prompts.size()];
|
client.input = k_prompts[rand() % k_prompts.size()];
|
||||||
client.prompt = k_system + client.input + "\nAssistant:";
|
client.prompt = k_system + client.input + "\nAssistant:";
|
||||||
@ -186,14 +186,21 @@ int main(int argc, char ** argv) {
|
|||||||
batch_logits.back() = true;
|
batch_logits.back() = true;
|
||||||
|
|
||||||
client.n_prompt = prompt_tokens.size();
|
client.n_prompt = prompt_tokens.size();
|
||||||
client.n_decoded = prompt_tokens.size();
|
client.n_decoded = 0;
|
||||||
client.i_batch = batch_token.size() - 1;
|
client.i_batch = batch_token.size() - 1;
|
||||||
|
|
||||||
g_seq_id += 1;
|
g_seq_id += 1;
|
||||||
|
if (hot_plug) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (batch_token.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// process in chunks of params.n_batch
|
// process in chunks of params.n_batch
|
||||||
for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
|
for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
|
||||||
n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
|
n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
|
||||||
@ -223,7 +230,9 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
|
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
|
||||||
|
|
||||||
if (client.t_start_gen == 0) {
|
if (client.n_decoded == 1) {
|
||||||
|
// start measuring generation time after the first token to make sure all concurrent clients
|
||||||
|
// have their prompt already processed
|
||||||
client.t_start_gen = ggml_time_us();
|
client.t_start_gen = ggml_time_us();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,9 +247,10 @@ int main(int argc, char ** argv) {
|
|||||||
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
|
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
|
||||||
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
|
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
|
||||||
|
|
||||||
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
|
if (client.n_decoded > 2 &&
|
||||||
client.response.find("User:") != std::string::npos ||
|
(id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
|
||||||
client.response.find('\n') != std::string::npos) {
|
client.response.find("User:") != std::string::npos ||
|
||||||
|
client.response.find('\n') != std::string::npos)) {
|
||||||
// basic reverse prompt
|
// basic reverse prompt
|
||||||
const size_t pos = client.response.find("User:");
|
const size_t pos = client.response.find("User:");
|
||||||
if (pos != std::string::npos) {
|
if (pos != std::string::npos) {
|
||||||
@ -252,18 +262,16 @@ int main(int argc, char ** argv) {
|
|||||||
const auto t_main_end = ggml_time_us();
|
const auto t_main_end = ggml_time_us();
|
||||||
|
|
||||||
printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
|
printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
|
||||||
client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
|
client.id, client.seq_id, client.n_prompt, client.n_decoded,
|
||||||
(t_main_end - client.t_start_prompt) / 1e6,
|
(t_main_end - client.t_start_prompt) / 1e6,
|
||||||
(double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
|
(double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
|
||||||
(double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6,
|
(double) (client.n_decoded ) / (t_main_end - client.t_start_gen) * 1e6,
|
||||||
(double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6,
|
(double) (client.n_decoded + client.n_prompt) / (t_main_end - client.t_start_prompt) * 1e6,
|
||||||
::trim(client.input).c_str(),
|
::trim(client.input).c_str(),
|
||||||
::trim(client.response).c_str());
|
::trim(client.response).c_str());
|
||||||
|
|
||||||
n_total_prompt += client.n_prompt;
|
n_total_prompt += client.n_prompt;
|
||||||
n_total_gen += client.n_decoded - client.n_prompt;
|
n_total_gen += client.n_decoded;
|
||||||
|
|
||||||
t_avg += (t_main_end - client.t_start_prompt) / 1e6;
|
|
||||||
|
|
||||||
client.seq_id = -1;
|
client.seq_id = -1;
|
||||||
}
|
}
|
||||||
@ -273,10 +281,12 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto t_main_end = ggml_time_us();
|
||||||
|
|
||||||
LOG_TEE("\n\n");
|
LOG_TEE("\n\n");
|
||||||
LOG_TEE("Total prompt tokens: %d\n", n_total_prompt);
|
LOG_TEE("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6);
|
||||||
LOG_TEE("Total gen tokens: %d\n", n_total_gen);
|
LOG_TEE("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6);
|
||||||
LOG_TEE("Avg time per seq: %.2f s\n", t_avg / n_seq);
|
LOG_TEE("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
|
||||||
|
|
||||||
LOG_TEE("\n\n");
|
LOG_TEE("\n\n");
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user