server : add speculative decoding support (#10455)

* server : add speculative decoding support

ggml-ci

* server : add helper function slot.can_speculate()

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-25 16:31:38 +02:00 committed by GitHub
parent 5931c1f233
commit 9ca2e67762
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,10 +2,11 @@
#include "arg.h" #include "arg.h"
#include "common.h" #include "common.h"
#include "log.h"
#include "sampling.h"
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "llama.h" #include "llama.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT: // Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
@ -121,12 +122,21 @@ struct slot_params {
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<std::string> antiprompt; std::vector<std::string> antiprompt;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
}; };
struct server_slot { struct server_slot {
int id; int id;
int id_task = -1; int id_task = -1;
llama_batch batch_spec;
llama_context * ctx_dft = nullptr;
common_speculative * spec = nullptr;
// the index relative to completion multi-task request // the index relative to completion multi-task request
size_t index = 0; size_t index = 0;
@ -175,7 +185,6 @@ struct server_slot {
// sampling // sampling
json json_schema; json json_schema;
struct common_params_sampling sparams;
struct common_sampler * smpl = nullptr; struct common_sampler * smpl = nullptr;
llama_token sampled; llama_token sampled;
@ -212,7 +221,7 @@ struct server_slot {
generated_token_probs.clear(); generated_token_probs.clear();
} }
bool has_budget(common_params &global_params) { bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) { if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless return true; // limitless
} }
@ -232,6 +241,10 @@ struct server_slot {
return state != SLOT_STATE_IDLE; return state != SLOT_STATE_IDLE;
} }
bool can_speculate() const {
return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
}
void add_token(const completion_token_output & token) { void add_token(const completion_token_output & token) {
if (!is_processing()) { if (!is_processing()) {
SLT_WRN(*this, "%s", "slot is not processing\n"); SLT_WRN(*this, "%s", "slot is not processing\n");
@ -591,11 +604,14 @@ struct server_response {
}; };
struct server_context { struct server_context {
common_params params_base;
llama_model * model = nullptr; llama_model * model = nullptr;
llama_context * ctx = nullptr; llama_context * ctx = nullptr;
std::vector<common_lora_adapter_container> loras; std::vector<common_lora_adapter_container> loras;
common_params params; llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
llama_batch batch = {}; llama_batch batch = {};
@ -628,27 +644,41 @@ struct server_context {
model = nullptr; model = nullptr;
} }
if (model_dft) {
llama_free_model(model_dft);
model_dft = nullptr;
}
// Clear any sampling context // Clear any sampling context
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
if (slot.smpl != nullptr) {
common_sampler_free(slot.smpl); common_sampler_free(slot.smpl);
} slot.smpl = nullptr;
llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr;
common_speculative_free(slot.spec);
slot.spec = nullptr;
llama_batch_free(slot.batch_spec);
} }
llama_batch_free(batch); llama_batch_free(batch);
} }
bool load_model(const common_params & params_) { bool load_model(const common_params & params) {
params = params_; SRV_INF("loading model '%s'\n", params.model.c_str());
common_init_result llama_init = common_init_from_params(params); params_base = params;
common_init_result llama_init = common_init_from_params(params_base);
model = llama_init.model; model = llama_init.model;
ctx = llama_init.context; ctx = llama_init.context;
loras = llama_init.lora_adapters; loras = llama_init.lora_adapters;
if (model == nullptr) { if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params.model.c_str()); SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
return false; return false;
} }
@ -657,6 +687,40 @@ struct server_context {
add_bos_token = llama_add_bos_token(model); add_bos_token = llama_add_bos_token(model);
has_eos_token = !llama_add_eos_token(model); has_eos_token = !llama_add_eos_token(model);
if (!params_base.speculative.model.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
auto params_dft = params_base;
params_dft.model = params_base.speculative.model;
params_dft.n_ctx = params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
common_init_result llama_init_dft = common_init_from_params(params_dft);
model_dft = llama_init_dft.model;
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
return false;
}
if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
llama_free (llama_init_dft.context);
llama_free_model(llama_init_dft.model);
return false;
}
cparams_dft = common_context_params_to_llama(params_base);
cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
// the context is not needed - we will create one for each slot
llama_free(llama_init_dft.context);
}
return true; return true;
} }
@ -674,20 +738,36 @@ struct server_context {
} }
void init() { void init() {
const int32_t n_ctx_slot = n_ctx / params.n_parallel; const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel); SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
for (int i = 0; i < params.n_parallel; i++) { for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot; server_slot slot;
slot.id = i; slot.id = i;
slot.n_ctx = n_ctx_slot; slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict; slot.n_predict = params_base.n_predict;
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return;
}
slot.spec = common_speculative_init(slot.ctx_dft);
if (slot.spec == nullptr) {
SRV_ERR("%s", "failed to create speculator\n");
return;
}
}
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
slot.sparams = params.sampling; slot.params.sampling = params_base.sampling;
slot.callback_on_release = [this](int) { slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task(); queue_tasks.pop_deferred_task();
@ -707,7 +787,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx); const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed // only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1); batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
} }
metrics.init(); metrics.init();
@ -786,9 +866,11 @@ struct server_context {
} }
bool launch_slot_with_task(server_slot & slot, const server_task & task) { bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
auto default_sparams = params.sampling; slot_params defaults;
defaults.sampling = params_base.sampling;
defaults.speculative = params_base.speculative;
const auto & data = task.data; const auto & data = task.data;
if (data.count("__oaicompat") != 0) { if (data.count("__oaicompat") != 0) {
@ -801,40 +883,46 @@ struct server_context {
slot.params.stream = json_value(data, "stream", false); slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", false); slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent); slot.params.n_indent = json_value(data, "n_indent", defaults.n_indent);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); slot.params.n_keep = json_value(data, "n_keep", defaults.n_keep);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability); slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
//slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
if (slot.sparams.dry_base < 1.0f) slot.params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
{ slot.params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
slot.sparams.dry_base = default_sparams.dry_base; slot.params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
slot.params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
slot.params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
slot.params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
slot.params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
slot.params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
slot.params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
slot.params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
slot.params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
slot.params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
slot.params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
slot.params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
slot.params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
slot.params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
slot.params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
if (slot.params.sampling.dry_base < 1.0f) {
slot.params.sampling.dry_base = defaults.sampling.dry_base;
} }
// sequence breakers for DRY // sequence breakers for DRY
@ -843,8 +931,8 @@ struct server_context {
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) { if (data.contains("dry_sequence_breakers")) {
slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>()); slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
if (slot.sparams.dry_sequence_breakers.empty()) { if (slot.params.sampling.dry_sequence_breakers.empty()) {
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST); send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} }
@ -859,13 +947,13 @@ struct server_context {
if (data.contains("json_schema") && !data.contains("grammar")) { if (data.contains("json_schema") && !data.contains("grammar")) {
try { try {
auto schema = json_value(data, "json_schema", json::object()); auto schema = json_value(data, "json_schema", json::object());
slot.sparams.grammar = json_schema_to_grammar(schema); slot.params.sampling.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) { } catch (const std::exception & e) {
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST); send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} }
} else { } else {
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar); slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
} }
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@ -875,10 +963,10 @@ struct server_context {
} }
{ {
slot.sparams.logit_bias.clear(); slot.params.sampling.logit_bias.clear();
if (json_value(data, "ignore_eos", false) && has_eos_token) { if (json_value(data, "ignore_eos", false) && has_eos_token) {
slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY}); slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
} }
const auto & logit_bias = data.find("logit_bias"); const auto & logit_bias = data.find("logit_bias");
@ -899,12 +987,12 @@ struct server_context {
if (el[0].is_number_integer()) { if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>(); llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) { if (tok >= 0 && tok < n_vocab) {
slot.sparams.logit_bias.push_back({tok, bias}); slot.params.sampling.logit_bias.push_back({tok, bias});
} }
} else if (el[0].is_string()) { } else if (el[0].is_string()) {
auto toks = common_tokenize(model, el[0].get<std::string>(), false); auto toks = common_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) { for (auto tok : toks) {
slot.sparams.logit_bias.push_back({tok, bias}); slot.params.sampling.logit_bias.push_back({tok, bias});
} }
} }
} }
@ -935,16 +1023,16 @@ struct server_context {
sampler_names.emplace_back(name); sampler_names.emplace_back(name);
} }
} }
slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false); slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
} else if (samplers->is_string()){ } else if (samplers->is_string()){
std::string sampler_string; std::string sampler_string;
for (const auto & name : *samplers) { for (const auto & name : *samplers) {
sampler_string += name; sampler_string += name;
} }
slot.sparams.samplers = common_sampler_types_from_chars(sampler_string); slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
} }
} else { } else {
slot.sparams.samplers = default_sparams.samplers; slot.params.sampling.samplers = defaults.sampling.samplers;
} }
} }
@ -953,7 +1041,7 @@ struct server_context {
common_sampler_free(slot.smpl); common_sampler_free(slot.smpl);
} }
slot.smpl = common_sampler_init(model, slot.sparams); slot.smpl = common_sampler_init(model, slot.params.sampling);
if (slot.smpl == nullptr) { if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar // for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@ -961,6 +1049,12 @@ struct server_context {
} }
} }
if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
}
slot.state = SLOT_STATE_STARTED; slot.state = SLOT_STATE_STARTED;
SLT_INF(slot, "%s", "processing task\n"); SLT_INF(slot, "%s", "processing task\n");
@ -978,7 +1072,7 @@ struct server_context {
bool process_token(completion_token_output & result, server_slot & slot) { bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling // remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = common_token_to_piece(ctx, result.tok, params.special); const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
slot.sampled = result.tok; slot.sampled = result.tok;
// search stop word and delete it // search stop word and delete it
@ -1043,7 +1137,7 @@ struct server_context {
} }
// check the limits // check the limits
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) { if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
slot.stopped_limit = true; slot.stopped_limit = true;
slot.has_next_token = false; slot.has_next_token = false;
@ -1136,50 +1230,54 @@ struct server_context {
json get_formated_generation(const server_slot & slot) const { json get_formated_generation(const server_slot & slot) const {
std::vector<std::string> samplers; std::vector<std::string> samplers;
samplers.reserve(slot.sparams.samplers.size()); samplers.reserve(slot.params.sampling.samplers.size());
for (const auto & sampler : slot.sparams.samplers) { for (const auto & sampler : slot.params.sampling.samplers) {
samplers.emplace_back(common_sampler_type_to_str(sampler)); samplers.emplace_back(common_sampler_type_to_str(sampler));
} }
return json { return json {
{"n_ctx", slot.n_ctx}, {"n_ctx", slot.n_ctx},
{"n_predict", slot.n_predict}, // Server configured n_predict {"n_predict", slot.n_predict}, // Server configured n_predict
{"model", params.model_alias}, {"model", params_base.model_alias},
{"seed", slot.sparams.seed}, {"seed", slot.params.sampling.seed},
{"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0}, {"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
{"temperature", slot.sparams.temp}, {"temperature", slot.params.sampling.temp},
{"dynatemp_range", slot.sparams.dynatemp_range}, {"dynatemp_range", slot.params.sampling.dynatemp_range},
{"dynatemp_exponent", slot.sparams.dynatemp_exponent}, {"dynatemp_exponent", slot.params.sampling.dynatemp_exponent},
{"top_k", slot.sparams.top_k}, {"top_k", slot.params.sampling.top_k},
{"top_p", slot.sparams.top_p}, {"top_p", slot.params.sampling.top_p},
{"min_p", slot.sparams.min_p}, {"min_p", slot.params.sampling.min_p},
{"xtc_probability", slot.sparams.xtc_probability}, {"xtc_probability", slot.params.sampling.xtc_probability},
{"xtc_threshold", slot.sparams.xtc_threshold}, {"xtc_threshold", slot.params.sampling.xtc_threshold},
{"typical_p", slot.sparams.typ_p}, {"typical_p", slot.params.sampling.typ_p},
{"repeat_last_n", slot.sparams.penalty_last_n}, {"repeat_last_n", slot.params.sampling.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat}, {"repeat_penalty", slot.params.sampling.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present}, {"presence_penalty", slot.params.sampling.penalty_present},
{"frequency_penalty", slot.sparams.penalty_freq}, {"frequency_penalty", slot.params.sampling.penalty_freq},
{"dry_multiplier", slot.sparams.dry_multiplier}, {"dry_multiplier", slot.params.sampling.dry_multiplier},
{"dry_base", slot.sparams.dry_base}, {"dry_base", slot.params.sampling.dry_base},
{"dry_allowed_length", slot.sparams.dry_allowed_length}, {"dry_allowed_length", slot.params.sampling.dry_allowed_length},
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n}, {"dry_penalty_last_n", slot.params.sampling.dry_penalty_last_n},
{"dry_sequence_breakers", slot.sparams.dry_sequence_breakers}, {"dry_sequence_breakers", slot.params.sampling.dry_sequence_breakers},
{"mirostat", slot.sparams.mirostat}, {"mirostat", slot.params.sampling.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_tau", slot.params.sampling.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta}, {"mirostat_eta", slot.params.sampling.mirostat_eta},
{"penalize_nl", slot.sparams.penalize_nl}, {"penalize_nl", slot.params.sampling.penalize_nl},
{"stop", slot.params.antiprompt}, {"stop", slot.params.antiprompt},
{"max_tokens", slot.params.n_predict}, // User configured n_predict {"max_tokens", slot.params.n_predict}, // User configured n_predict
{"n_keep", slot.params.n_keep}, {"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard}, {"n_discard", slot.params.n_discard},
{"ignore_eos", slot.sparams.ignore_eos}, {"ignore_eos", slot.params.sampling.ignore_eos},
{"stream", slot.params.stream}, {"stream", slot.params.stream},
//{"logit_bias", slot.sparams.logit_bias}, //{"logit_bias", slot.params.sampling.logit_bias},
{"n_probs", slot.sparams.n_probs}, {"n_probs", slot.params.sampling.n_probs},
{"min_keep", slot.sparams.min_keep}, {"min_keep", slot.params.sampling.min_keep},
{"grammar", slot.sparams.grammar}, {"grammar", slot.params.sampling.grammar},
{"samplers", samplers}, {"samplers", samplers},
{"speculative", slot.can_speculate()},
{"speculative.n_max", slot.params.speculative.n_max},
{"speculative.n_min", slot.params.speculative.n_min},
{"speculative.p_min", slot.params.speculative.p_min},
}; };
} }
@ -1216,7 +1314,7 @@ struct server_context {
{"index", slot.index}, {"index", slot.index},
}; };
if (slot.sparams.n_probs > 0) { if (slot.params.sampling.n_probs > 0) {
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
@ -1249,7 +1347,7 @@ struct server_context {
{"content", !slot.params.stream ? slot.generated_text : ""}, {"content", !slot.params.stream ? slot.generated_text : ""},
{"id_slot", slot.id}, {"id_slot", slot.id},
{"stop", true}, {"stop", true},
{"model", params.model_alias}, {"model", params_base.model_alias},
{"tokens_predicted", slot.n_decoded}, {"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens}, {"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)}, {"generation_settings", get_formated_generation(slot)},
@ -1265,7 +1363,7 @@ struct server_context {
{"index", slot.index}, {"index", slot.index},
}; };
if (slot.sparams.n_probs > 0) { if (slot.params.sampling.n_probs > 0) {
std::vector<completion_token_output> probs; std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) { if (!slot.params.stream && slot.stopped_word) {
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
@ -1422,10 +1520,10 @@ struct server_context {
data.at("input_prefix"), data.at("input_prefix"),
data.at("input_suffix"), data.at("input_suffix"),
data.at("input_extra"), data.at("input_extra"),
params.n_batch, params_base.n_batch,
params.n_predict, params_base.n_predict,
slots[0].n_ctx, // TODO: there should be a better way slots[0].n_ctx, // TODO: there should be a better way
params.spm_infill, params_base.spm_infill,
tokenized_prompts[i] tokenized_prompts[i]
); );
create_task(data, tokens); create_task(data, tokens);
@ -1798,7 +1896,7 @@ struct server_context {
// TODO: simplify and improve // TODO: simplify and improve
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
if (!params.ctx_shift) { if (!params_base.ctx_shift) {
// this check is redundant (for good) // this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token() // we should never get here, because generation should already stopped in process_token()
slot.release(); slot.release();
@ -1864,7 +1962,7 @@ struct server_context {
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch // next, batch any pending prompts without exceeding n_batch
if (params.cont_batching || batch.n_tokens == 0) { if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) { for (auto & slot : slots) {
// this slot still has a prompt to be processed // this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
@ -1917,7 +2015,7 @@ struct server_context {
continue; continue;
} }
} else { } else {
if (!params.ctx_shift) { if (!params_base.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size // if context shift is disabled, we make sure prompt size is smaller than KV size
// TODO: there should be a separate parameter that control prompt truncation // TODO: there should be a separate parameter that control prompt truncation
// context shift should be applied only during the generation phase // context shift should be applied only during the generation phase
@ -1963,11 +2061,11 @@ struct server_context {
slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
// reuse chunks from the cached prompt by shifting their KV cache in the new position // reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params.n_cache_reuse > 0) { if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache size_t head_c = slot.n_past; // cache
size_t head_p = slot.n_past; // current prompt size_t head_p = slot.n_past; // current prompt
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past); SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
while (head_c < slot.cache_tokens.size() && while (head_c < slot.cache_tokens.size() &&
head_p < prompt_tokens.size()) { head_p < prompt_tokens.size()) {
@ -1980,7 +2078,7 @@ struct server_context {
n_match++; n_match++;
} }
if (n_match >= (size_t) params.n_cache_reuse) { if (n_match >= (size_t) params_base.n_cache_reuse) {
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
//for (size_t i = head_p; i < head_p + n_match; i++) { //for (size_t i = head_p; i < head_p + n_match; i++) {
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
@ -2168,8 +2266,14 @@ struct server_context {
continue; // continue loop of slots continue; // continue loop of slots
} }
llama_token id;
{
completion_token_output result; completion_token_output result;
const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
slot.i_batch = -1;
common_sampler_accept(slot.smpl, id, true); common_sampler_accept(slot.smpl, id, true);
@ -2184,7 +2288,7 @@ struct server_context {
const auto * cur_p = common_sampler_get_candidates(slot.smpl); const auto * cur_p = common_sampler_get_candidates(slot.smpl);
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
result.probs.push_back({ result.probs.push_back({
cur_p->data[i].id, cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p, i >= cur_p->size ? 0.0f : cur_p->data[i].p,
@ -2197,9 +2301,64 @@ struct server_context {
slot.print_timings(); slot.print_timings();
send_final_response(slot); send_final_response(slot);
metrics.on_prediction(slot); metrics.on_prediction(slot);
continue;
}
} }
slot.i_batch = -1; // check if the slot supports speculative decoding
if (!slot.can_speculate()) {
continue;
}
struct common_speculative_params params_spec;
params_spec.n_draft = slot.params.speculative.n_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
continue;
}
// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}
llama_decode(ctx, slot.batch_spec);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
slot.n_past += ids.size();
slot.n_decoded += ids.size();
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
} }
} }
@ -2697,7 +2856,7 @@ int main(int argc, char ** argv) {
const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
json data = { json data = {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params.n_parallel }, { "total_slots", ctx_server.params_base.n_parallel },
{ "chat_template", llama_get_chat_template(ctx_server.model) }, { "chat_template", llama_get_chat_template(ctx_server.model) },
}; };
@ -2705,7 +2864,7 @@ int main(int argc, char ** argv) {
}; };
const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params.endpoint_props) { if (!ctx_server.params_base.endpoint_props) {
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -2718,7 +2877,7 @@ int main(int argc, char ** argv) {
}; };
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) { const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding) { if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -2824,7 +2983,7 @@ int main(int argc, char ** argv) {
// TODO: maybe merge this function with "handle_completions_generic" // TODO: maybe merge this function with "handle_completions_generic"
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) { const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) { if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -3001,7 +3160,7 @@ int main(int argc, char ** argv) {
}; };
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params.reranking || ctx_server.params.embedding) { if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }