server : fix parallel speculative decoding (#10513)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-26 13:36:40 +02:00 committed by GitHub
parent 811872a59d
commit 84e1c33cde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2267,12 +2267,7 @@ struct server_context {
continue; // continue loop of slots continue; // continue loop of slots
} }
llama_token id; llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
{
completion_token_output result;
id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
slot.i_batch = -1; slot.i_batch = -1;
@ -2285,6 +2280,7 @@ struct server_context {
metrics.on_prompt_eval(slot); metrics.on_prompt_eval(slot);
} }
completion_token_output result;
result.tok = id; result.tok = id;
const auto * cur_p = common_sampler_get_candidates(slot.smpl); const auto * cur_p = common_sampler_get_candidates(slot.smpl);
@ -2306,11 +2302,14 @@ struct server_context {
} }
} }
// check if the slot supports speculative decoding // do speculative decoding
if (!slot.can_speculate()) { for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
continue; continue;
} }
llama_token id = slot.sampled;
struct common_speculative_params params_spec; struct common_speculative_params params_spec;
params_spec.n_draft = slot.params.speculative.n_max; params_spec.n_draft = slot.params.speculative.n_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;