mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
parallel : various improvements
This commit is contained in:
parent
467e307931
commit
36714e16d0
@ -80,10 +80,13 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int n_clients = 4;
|
||||
const int n_clients = 8;
|
||||
|
||||
// insert new requests as soon as the previous one is done
|
||||
const bool hot_swap = true;
|
||||
const bool hot_plug = false;
|
||||
|
||||
// requests to simulate
|
||||
const int32_t n_seq = 128;
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_set_target(log_filename_generator("parallel", "log"));
|
||||
@ -95,7 +98,6 @@ int main(int argc, char ** argv) {
|
||||
llama_backend_init(params.numa);
|
||||
|
||||
llama_model * model = NULL;
|
||||
|
||||
llama_context * ctx = NULL;
|
||||
|
||||
// load the target model
|
||||
@ -130,11 +132,9 @@ int main(int argc, char ** argv) {
|
||||
int32_t n_total_prompt = 0;
|
||||
int32_t n_total_gen = 0;
|
||||
|
||||
float t_avg = 0.0f;
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
const int32_t n_seq = 128;
|
||||
|
||||
while (g_seq_id < n_seq + n_clients) {
|
||||
while (true) {
|
||||
uint32_t n_tokens = 0;
|
||||
|
||||
batch_token.clear();
|
||||
@ -148,7 +148,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
batch_token.push_back(client.sampled);
|
||||
batch_pos.push_back(client.n_decoded);
|
||||
batch_pos.push_back(client.n_decoded + client.n_prompt);
|
||||
batch_seq_id.push_back(client.seq_id);
|
||||
batch_logits.push_back(true);
|
||||
batch_clients.push_back(&client);
|
||||
@ -161,12 +161,12 @@ int main(int argc, char ** argv) {
|
||||
llama_kv_cache_rm_tokens(ctx, -1, -1);
|
||||
}
|
||||
|
||||
if (hot_swap || batch_token.empty()) {
|
||||
if (hot_plug || batch_token.empty()) {
|
||||
for (auto & client : clients) {
|
||||
if (client.seq_id == -1) {
|
||||
if (client.seq_id == -1 && g_seq_id < n_seq) {
|
||||
client.seq_id = g_seq_id;
|
||||
client.t_start_prompt = ggml_time_us();
|
||||
client.t_start_gen = 0;
|
||||
client.t_start_gen = 0;
|
||||
|
||||
client.input = k_prompts[rand() % k_prompts.size()];
|
||||
client.prompt = k_system + client.input + "\nAssistant:";
|
||||
@ -186,14 +186,21 @@ int main(int argc, char ** argv) {
|
||||
batch_logits.back() = true;
|
||||
|
||||
client.n_prompt = prompt_tokens.size();
|
||||
client.n_decoded = prompt_tokens.size();
|
||||
client.n_decoded = 0;
|
||||
client.i_batch = batch_token.size() - 1;
|
||||
|
||||
g_seq_id += 1;
|
||||
if (hot_plug) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (batch_token.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// process in chunks of params.n_batch
|
||||
for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
|
||||
n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
|
||||
@ -223,7 +230,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
|
||||
|
||||
if (client.t_start_gen == 0) {
|
||||
if (client.n_decoded == 1) {
|
||||
// start measuring generation time after the first token to make sure all concurrent clients
|
||||
// have their prompt already processed
|
||||
client.t_start_gen = ggml_time_us();
|
||||
}
|
||||
|
||||
@ -238,9 +247,10 @@ int main(int argc, char ** argv) {
|
||||
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
|
||||
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
|
||||
|
||||
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
|
||||
client.response.find("User:") != std::string::npos ||
|
||||
client.response.find('\n') != std::string::npos) {
|
||||
if (client.n_decoded > 2 &&
|
||||
(id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
|
||||
client.response.find("User:") != std::string::npos ||
|
||||
client.response.find('\n') != std::string::npos)) {
|
||||
// basic reverse prompt
|
||||
const size_t pos = client.response.find("User:");
|
||||
if (pos != std::string::npos) {
|
||||
@ -252,18 +262,16 @@ int main(int argc, char ** argv) {
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
|
||||
client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
|
||||
client.id, client.seq_id, client.n_prompt, client.n_decoded,
|
||||
(t_main_end - client.t_start_prompt) / 1e6,
|
||||
(double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
|
||||
(double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6,
|
||||
(double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6,
|
||||
(double) (client.n_decoded ) / (t_main_end - client.t_start_gen) * 1e6,
|
||||
(double) (client.n_decoded + client.n_prompt) / (t_main_end - client.t_start_prompt) * 1e6,
|
||||
::trim(client.input).c_str(),
|
||||
::trim(client.response).c_str());
|
||||
|
||||
n_total_prompt += client.n_prompt;
|
||||
n_total_gen += client.n_decoded - client.n_prompt;
|
||||
|
||||
t_avg += (t_main_end - client.t_start_prompt) / 1e6;
|
||||
n_total_gen += client.n_decoded;
|
||||
|
||||
client.seq_id = -1;
|
||||
}
|
||||
@ -273,10 +281,12 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
LOG_TEE("\n\n");
|
||||
LOG_TEE("Total prompt tokens: %d\n", n_total_prompt);
|
||||
LOG_TEE("Total gen tokens: %d\n", n_total_gen);
|
||||
LOG_TEE("Avg time per seq: %.2f s\n", t_avg / n_seq);
|
||||
LOG_TEE("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6);
|
||||
LOG_TEE("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6);
|
||||
LOG_TEE("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
|
||||
|
||||
LOG_TEE("\n\n");
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user