parallel : try smaller batches when the KV cache is fragmented

This commit is contained in:
Georgi Gerganov 2023-09-19 13:21:36 +03:00
parent ddad227782
commit 806d397c1a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
const int n_clients = 8; const int n_clients = 8;
// insert new requests as soon as the previous one is done // insert new requests as soon as the previous one is done
const bool hot_plug = false; const bool hot_plug = true;
// requests to simulate // requests to simulate
const int32_t n_seq = 128; const int32_t n_seq = 128;
@ -202,8 +202,10 @@ int main(int argc, char ** argv) {
} }
// process in chunks of params.n_batch // process in chunks of params.n_batch
for (size_t i = 0; i < batch_token.size(); i += params.n_batch) { int32_t n_batch = params.n_batch;
n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) {
n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i));
llama_batch batch = { llama_batch batch = {
n_tokens, n_tokens,
@ -216,10 +218,22 @@ int main(int argc, char ** argv) {
}; };
if (llama_decode(ctx, batch, params.n_threads)) { if (llama_decode(ctx, batch, params.n_threads)) {
LOG_TEE("%s : failed to decode batch\n", __func__); if (n_batch == 1) {
return 1; LOG_TEE("%s : failed to decode batch\n", __func__);
return 1;
}
LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
continue;
} }
LOG_TEE("%s : decoded batch of %d tokens\n", __func__, n_tokens);
for (auto & client : clients) { for (auto & client : clients) {
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) { if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
continue; continue;