mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 11:23:56 +01:00
parallel : try smaller batches when the KV cache is fragmented
This commit is contained in:
parent
ddad227782
commit
806d397c1a
@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
|
|||||||
const int n_clients = 8;
|
const int n_clients = 8;
|
||||||
|
|
||||||
// insert new requests as soon as the previous one is done
|
// insert new requests as soon as the previous one is done
|
||||||
const bool hot_plug = false;
|
const bool hot_plug = true;
|
||||||
|
|
||||||
// requests to simulate
|
// requests to simulate
|
||||||
const int32_t n_seq = 128;
|
const int32_t n_seq = 128;
|
||||||
@ -202,8 +202,10 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// process in chunks of params.n_batch
|
// process in chunks of params.n_batch
|
||||||
for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
|
int32_t n_batch = params.n_batch;
|
||||||
n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
|
|
||||||
|
for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) {
|
||||||
|
n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i));
|
||||||
|
|
||||||
llama_batch batch = {
|
llama_batch batch = {
|
||||||
n_tokens,
|
n_tokens,
|
||||||
@ -216,10 +218,22 @@ int main(int argc, char ** argv) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||||
LOG_TEE("%s : failed to decode batch\n", __func__);
|
if (n_batch == 1) {
|
||||||
return 1;
|
LOG_TEE("%s : failed to decode batch\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
|
||||||
|
|
||||||
|
// retry with half the batch size to try to find a free slot in the KV cache
|
||||||
|
n_batch /= 2;
|
||||||
|
i -= n_batch;
|
||||||
|
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LOG_TEE("%s : decoded batch of %d tokens\n", __func__, n_tokens);
|
||||||
|
|
||||||
for (auto & client : clients) {
|
for (auto & client : clients) {
|
||||||
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
|
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
|
||||||
continue;
|
continue;
|
||||||
|
Loading…
Reference in New Issue
Block a user