diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 3c3fe6ddb..c35552e4a 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -83,7 +83,7 @@ int main(int argc, char ** argv) { const int n_clients = 8; // insert new requests as soon as the previous one is done - const bool hot_plug = false; + const bool hot_plug = true; // requests to simulate const int32_t n_seq = 128; @@ -202,8 +202,10 @@ int main(int argc, char ** argv) { } // process in chunks of params.n_batch - for (size_t i = 0; i < batch_token.size(); i += params.n_batch) { - n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i)); + int32_t n_batch = params.n_batch; + + for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) { + n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i)); llama_batch batch = { n_tokens, @@ -216,10 +218,22 @@ int main(int argc, char ** argv) { }; if (llama_decode(ctx, batch, params.n_threads)) { - LOG_TEE("%s : failed to decode batch\n", __func__); - return 1; + if (n_batch == 1) { + LOG_TEE("%s : failed to decode batch\n", __func__); + return 1; + } + + LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2); + + // retry with half the batch size to try to find a free slot in the KV cache + n_batch /= 2; + i -= n_batch; + + continue; } + LOG_TEE("%s : decoded batch of %d tokens\n", __func__, n_tokens); + for (auto & client : clients) { if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) { continue;