server : fix speculative decoding with context shift (#10641)

* server : fix speculative decoding with context shift

ggml-ci

* server : take into account speculative limits

ggml-ci

* server : add tests
This commit is contained in:
Georgi Gerganov 2024-12-04 22:38:20 +02:00 committed by GitHub
parent 59f4db1088
commit 1da7b76569
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 2 deletions

View File

@ -921,6 +921,8 @@ struct server_context {
slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
if (slot.params.sampling.dry_base < 1.0f) {
slot.params.sampling.dry_base = defaults.sampling.dry_base;
@ -2322,10 +2324,29 @@ struct server_context {
continue;
}
// determine the max draft that fits the current slot state
int n_draft_max = slot.params.speculative.n_max;
// note: n_past is not yet increased for the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
}
SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
if (n_draft_max < slot.params.speculative.n_min) {
SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
continue;
}
llama_token id = slot.sampled;
struct common_speculative_params params_spec;
params_spec.n_draft = slot.params.speculative.n_max;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
@ -2333,6 +2354,8 @@ struct server_context {
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
continue;
}
@ -2344,6 +2367,8 @@ struct server_context {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
llama_decode(ctx, slot.batch_spec);
// the accepted tokens from the speculation
@ -2372,7 +2397,7 @@ struct server_context {
}
}
SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
}
}

View File

@ -82,6 +82,37 @@ def test_different_draft_min_draft_max():
last_content = res.body["content"]
def test_slot_ctx_not_exceeded():
global server
server.n_ctx = 64
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "Hello " * 56,
"temperature": 0.0,
"top_k": 1,
"speculative.p_min": 0.0,
})
assert res.status_code == 200
assert len(res.body["content"]) > 0
def test_with_ctx_shift():
global server
server.n_ctx = 64
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "Hello " * 56,
"temperature": 0.0,
"top_k": 1,
"n_predict": 64,
"speculative.p_min": 0.0,
})
assert res.status_code == 200
assert len(res.body["content"]) > 0
assert res.body["tokens_predicted"] == 64
assert res.body["truncated"] == True
@pytest.mark.parametrize("n_slots,n_requests", [
(1, 2),
(2, 2),