diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c0ea4faf7..9c86407c2 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2267,49 +2267,48 @@ struct server_context { continue; // continue loop of slots } - llama_token id; + llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i); - { - completion_token_output result; + slot.i_batch = -1; - id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i); + common_sampler_accept(slot.smpl, id, true); - slot.i_batch = -1; - - common_sampler_accept(slot.smpl, id, true); - - slot.n_decoded += 1; - if (slot.n_decoded == 1) { - slot.t_start_generation = ggml_time_us(); - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - result.tok = id; - - const auto * cur_p = common_sampler_get_candidates(slot.smpl); - - for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) { - result.probs.push_back({ - cur_p->data[i].id, - i >= cur_p->size ? 0.0f : cur_p->data[i].p, - }); - } - - if (!process_token(result, slot)) { - // release slot because of stop condition - slot.release(); - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - continue; - } + slot.n_decoded += 1; + if (slot.n_decoded == 1) { + slot.t_start_generation = ggml_time_us(); + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); } - // check if the slot supports speculative decoding - if (!slot.can_speculate()) { + completion_token_output result; + result.tok = id; + + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + + for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) { + result.probs.push_back({ + cur_p->data[i].id, + i >= cur_p->size ? 0.0f : cur_p->data[i].p, + }); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); continue; } + } + + // do speculative decoding + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; + } + + llama_token id = slot.sampled; struct common_speculative_params params_spec; params_spec.n_draft = slot.params.speculative.n_max;