parallel : try smaller batches when the KV cache is fragmented

This commit is contained in:
Georgi Gerganov 2023-09-19 13:21:36 +03:00
parent ddad227782
commit 806d397c1a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
const int n_clients = 8;
// insert new requests as soon as the previous one is done
const bool hot_plug = false;
const bool hot_plug = true;
// requests to simulate
const int32_t n_seq = 128;
@ -202,8 +202,10 @@ int main(int argc, char ** argv) {
}
// process in chunks of params.n_batch
for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) {
n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i));
llama_batch batch = {
n_tokens,
@ -216,10 +218,22 @@ int main(int argc, char ** argv) {
};
if (llama_decode(ctx, batch, params.n_threads)) {
LOG_TEE("%s : failed to decode batch\n", __func__);
return 1;
if (n_batch == 1) {
LOG_TEE("%s : failed to decode batch\n", __func__);
return 1;
}
LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
continue;
}
LOG_TEE("%s : decoded batch of %d tokens\n", __func__, n_tokens);
for (auto & client : clients) {
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
continue;