Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2024-12-26 14:20:31 +01:00
server : add speculative decoding support (#10455)
* server : add speculative decoding support

ggml-ci

* server : add helper function slot.can_speculate()

ggml-ci
This commit is contained in:
parent 5931c1f233
commit 9ca2e67762
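Usage sketch (illustrative, not part of the commit): the server is started with a small draft model next to the target model; per-request settings can further tune the drafting (see the parameter parsing below). The flag names here are assumptions based on the common speculative-decoding options — check `llama-server --help` on your build for the exact spelling:

    llama-server -m model-7b.gguf -md model-1b.gguf --draft-max 16 --draft-min 5

The draft model cheaply proposes a short continuation, the target model verifies the whole proposal in a single decode, and only the longest accepted prefix is kept, so the output still follows the target model's sampling.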
@@ -2,10 +2,11 @@
 
 #include "arg.h"
 #include "common.h"
-#include "log.h"
-#include "sampling.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "log.h"
+#include "sampling.h"
+#include "speculative.h"
 
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
@@ -121,12 +122,21 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
+
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;
 };
 
 struct server_slot {
     int id;
     int id_task = -1;
 
+    llama_batch batch_spec;
+
+    llama_context * ctx_dft = nullptr;
+
+    common_speculative * spec = nullptr;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
@@ -175,7 +185,6 @@ struct server_slot {
     // sampling
     json json_schema;
 
-    struct common_params_sampling sparams;
     struct common_sampler * smpl = nullptr;
 
     llama_token sampled;
@@ -212,7 +221,7 @@ struct server_slot {
         generated_token_probs.clear();
     }
 
-    bool has_budget(common_params &global_params) {
+    bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
         }
@@ -232,6 +241,10 @@ struct server_slot {
         return state != SLOT_STATE_IDLE;
     }
 
+    bool can_speculate() const {
+        return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+    }
+
     void add_token(const completion_token_output & token) {
         if (!is_processing()) {
             SLT_WRN(*this, "%s", "slot is not processing\n");
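Commentary (not in the source): a slot may only speculate when all three conditions hold — its draft context exists (i.e. a draft model was loaded), the per-request draft budget is positive, and prompt caching is on, since drafting extends the slot's cached token sequence. The gate is used later in the decode loop to fall back to plain decoding:

    if (!slot.can_speculate()) {
        continue; // normal single-token decoding for this slot
    }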
@@ -591,11 +604,14 @@ struct server_response {
 };
 
 struct server_context {
+    common_params params_base;
+
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     std::vector<common_lora_adapter_container> loras;
 
-    common_params params;
+    llama_model * model_dft = nullptr;
+    llama_context_params cparams_dft;
 
     llama_batch batch = {};
 
@@ -628,27 +644,41 @@ struct server_context {
             model = nullptr;
         }
 
+        if (model_dft) {
+            llama_free_model(model_dft);
+            model_dft = nullptr;
+        }
+
         // Clear any sampling context
         for (server_slot & slot : slots) {
-            if (slot.smpl != nullptr) {
-                common_sampler_free(slot.smpl);
-            }
+            common_sampler_free(slot.smpl);
+            slot.smpl = nullptr;
+
+            llama_free(slot.ctx_dft);
+            slot.ctx_dft = nullptr;
+
+            common_speculative_free(slot.spec);
+            slot.spec = nullptr;
+
+            llama_batch_free(slot.batch_spec);
         }
 
         llama_batch_free(batch);
     }
 
-    bool load_model(const common_params & params_) {
-        params = params_;
+    bool load_model(const common_params & params) {
+        SRV_INF("loading model '%s'\n", params.model.c_str());
 
-        common_init_result llama_init = common_init_from_params(params);
+        params_base = params;
+
+        common_init_result llama_init = common_init_from_params(params_base);
 
         model = llama_init.model;
         ctx = llama_init.context;
         loras = llama_init.lora_adapters;
 
         if (model == nullptr) {
-            SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
+            SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
             return false;
         }
 
@@ -657,6 +687,40 @@ struct server_context {
         add_bos_token = llama_add_bos_token(model);
         has_eos_token = !llama_add_eos_token(model);
 
+        if (!params_base.speculative.model.empty()) {
+            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
+
+            auto params_dft = params_base;
+
+            params_dft.model = params_base.speculative.model;
+            params_dft.n_ctx = params_base.speculative.n_ctx;
+            params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
+
+            common_init_result llama_init_dft = common_init_from_params(params_dft);
+
+            model_dft = llama_init_dft.model;
+
+            if (model_dft == nullptr) {
+                SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
+                return false;
+            }
+
+            if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
+
+                llama_free      (llama_init_dft.context);
+                llama_free_model(llama_init_dft.model);
+
+                return false;
+            }
+
+            cparams_dft = common_context_params_to_llama(params_base);
+            cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
+
+            // the context is not needed - we will create one for each slot
+            llama_free(llama_init_dft.context);
+        }
+
         return true;
     }
 
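For orientation (an illustrative sketch, not part of the patch): the draft-model plumbing above relies on the common speculative helpers pulled in via `speculative.h`. Pieced together from the calls in this diff, the per-slot lifecycle looks roughly like this — `ctx_dft`, `prompt_tokens` and `id_last` are placeholder names:

    // one speculator per slot, built on the slot's draft-model context
    common_speculative * spec = common_speculative_init(ctx_dft);

    struct common_speculative_params params_spec;
    params_spec.n_draft = 16;                        // upper bound on drafted tokens (speculative.n_max)
    params_spec.n_reuse = llama_n_ctx(ctx_dft) - 16; // draft-cache reuse window, as computed in the diff
    params_spec.p_min   = 0.9f;                      // stop drafting when the draft model gets unsure

    // propose a continuation of the cached prompt after the last sampled token
    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tokens, id_last);

    // ... verify `draft` with the target model (see the decode loop near the end of this diff) ...

    common_speculative_free(spec);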
@@ -674,20 +738,36 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params.n_parallel;
+        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
 
-        SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
+        SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
-        for (int i = 0; i < params.n_parallel; i++) {
+        for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
             slot.id = i;
             slot.n_ctx = n_ctx_slot;
-            slot.n_predict = params.n_predict;
+            slot.n_predict = params_base.n_predict;
 
+            if (model_dft) {
+                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+
+                slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
+                if (slot.ctx_dft == nullptr) {
+                    SRV_ERR("%s", "failed to create draft context\n");
+                    return;
+                }
+
+                slot.spec = common_speculative_init(slot.ctx_dft);
+                if (slot.spec == nullptr) {
+                    SRV_ERR("%s", "failed to create speculator\n");
+                    return;
+                }
+            }
+
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-            slot.sparams = params.sampling;
+            slot.params.sampling = params_base.sampling;
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -707,7 +787,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
         }
 
         metrics.init();
@@ -786,9 +866,11 @@ struct server_context {
     }
 
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
-        slot_params default_params;
-        // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        auto default_sparams = params.sampling;
+        slot_params defaults;
+        defaults.sampling = params_base.sampling;
+        defaults.speculative = params_base.speculative;
+
         const auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
@@ -799,42 +881,48 @@ struct server_context {
             slot.oaicompat_model = "";
         }
 
-        slot.params.stream = json_value(data, "stream", false);
-        slot.params.cache_prompt = json_value(data, "cache_prompt", false);
-        slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
-        slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
-        slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
-        slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
-        slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
-        slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
-        slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
-        slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
-        slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
-        slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
-        slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
-        slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
-        slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
-        slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
-        slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
-        slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
-        slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
-        slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
-        slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
-        slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
-        slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
-        slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
-        slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
-        slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
-        slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
-        slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
-        slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
-        slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
-        //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
-        slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
+        slot.params.stream = json_value(data, "stream", false);
+        slot.params.cache_prompt = json_value(data, "cache_prompt", false);
+        slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+        slot.params.n_indent = json_value(data, "n_indent", defaults.n_indent);
+        slot.params.n_keep = json_value(data, "n_keep", defaults.n_keep);
+        slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+        //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
+        slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
 
-        if (slot.sparams.dry_base < 1.0f)
-        {
-            slot.sparams.dry_base = default_sparams.dry_base;
+        slot.params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
+        slot.params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
+        slot.params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
+        slot.params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
+        slot.params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
+        slot.params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
+        slot.params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
+        slot.params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
+        slot.params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
+        slot.params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
+        slot.params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
+        slot.params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
+        slot.params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
+        slot.params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
+        slot.params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
+        slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+        slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+        slot.params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
+        slot.params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
+        slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+        slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
+        slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
+        slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+        slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+
+        slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
+        slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
+        slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+
+        slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
+
+        if (slot.params.sampling.dry_base < 1.0f) {
+            slot.params.sampling.dry_base = defaults.sampling.dry_base;
         }
 
         // sequence breakers for DRY
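Illustration (not part of the patch): per the parsing above, the speculative settings can now be overridden per request with flat, dot-separated keys. Note that `cache_prompt` must be true for a slot to speculate at all (see `slot.can_speculate()`), and `n_min` is clamped to `n_max` by the `std::min` above. Values here are arbitrary:

    {
        "prompt": "Write a haiku about autumn",
        "n_predict": 128,
        "cache_prompt": true,
        "speculative.n_max": 16,
        "speculative.n_min": 5,
        "speculative.p_min": 0.9
    }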
@@ -843,8 +931,8 @@ struct server_context {
         // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
 
         if (data.contains("dry_sequence_breakers")) {
-            slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
-            if (slot.sparams.dry_sequence_breakers.empty()) {
+            slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+            if (slot.params.sampling.dry_sequence_breakers.empty()) {
                 send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
@@ -858,14 +946,14 @@ struct server_context {
         }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema = json_value(data, "json_schema", json::object());
-                slot.sparams.grammar = json_schema_to_grammar(schema);
+                auto schema = json_value(data, "json_schema", json::object());
+                slot.params.sampling.grammar = json_schema_to_grammar(schema);
             } catch (const std::exception & e) {
                 send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         } else {
-            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
+            slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
         }
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -875,10 +963,10 @@ struct server_context {
        }
 
         {
-            slot.sparams.logit_bias.clear();
+            slot.params.sampling.logit_bias.clear();
 
             if (json_value(data, "ignore_eos", false) && has_eos_token) {
-                slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
+                slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
             }
 
             const auto & logit_bias = data.find("logit_bias");
@@ -899,12 +987,12 @@ struct server_context {
                     if (el[0].is_number_integer()) {
                         llama_token tok = el[0].get<llama_token>();
                         if (tok >= 0 && tok < n_vocab) {
-                            slot.sparams.logit_bias.push_back({tok, bias});
+                            slot.params.sampling.logit_bias.push_back({tok, bias});
                         }
                     } else if (el[0].is_string()) {
                         auto toks = common_tokenize(model, el[0].get<std::string>(), false);
                         for (auto tok : toks) {
-                            slot.sparams.logit_bias.push_back({tok, bias});
+                            slot.params.sampling.logit_bias.push_back({tok, bias});
                         }
                     }
                 }
@@ -935,16 +1023,16 @@ struct server_context {
                         sampler_names.emplace_back(name);
                     }
                 }
-                slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false);
+                slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
             } else if (samplers->is_string()){
                 std::string sampler_string;
                 for (const auto & name : *samplers) {
                     sampler_string += name;
                 }
-                slot.sparams.samplers = common_sampler_types_from_chars(sampler_string);
+                slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
             }
         } else {
-            slot.sparams.samplers = default_sparams.samplers;
+            slot.params.sampling.samplers = defaults.sampling.samplers;
         }
     }
 
@@ -953,7 +1041,7 @@ struct server_context {
             common_sampler_free(slot.smpl);
         }
 
-        slot.smpl = common_sampler_init(model, slot.sparams);
+        slot.smpl = common_sampler_init(model, slot.params.sampling);
         if (slot.smpl == nullptr) {
             // for now, the only error that may happen here is invalid grammar
             send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -961,6 +1049,12 @@ struct server_context {
             }
         }
 
+        if (slot.ctx_dft) {
+            llama_batch_free(slot.batch_spec);
+
+            slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+        }
+
        slot.state = SLOT_STATE_STARTED;
 
        SLT_INF(slot, "%s", "processing task\n");
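A note on the `+ 1` sizing above (commentary, not in the source): the speculation batch has to hold the one token just sampled from the target model plus up to `n_max` drafted tokens, which is exactly how it is filled later — one `common_batch_add` for `id`, then one per draft token:

    // capacity = 1 (last sampled token) + n_max (drafted tokens)
    slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);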
@@ -978,7 +1072,7 @@ struct server_context {
 
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
+        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
         slot.sampled = result.tok;
 
         // search stop word and delete it
@@ -1043,7 +1137,7 @@ struct server_context {
         }
 
         // check the limits
-        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
+        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
             slot.stopped_limit = true;
             slot.has_next_token = false;
 
@@ -1136,50 +1230,54 @@ struct server_context {
 
     json get_formated_generation(const server_slot & slot) const {
         std::vector<std::string> samplers;
-        samplers.reserve(slot.sparams.samplers.size());
-        for (const auto & sampler : slot.sparams.samplers) {
+        samplers.reserve(slot.params.sampling.samplers.size());
+        for (const auto & sampler : slot.params.sampling.samplers) {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }
 
         return json {
             {"n_ctx", slot.n_ctx},
             {"n_predict", slot.n_predict}, // Server configured n_predict
-            {"model", params.model_alias},
-            {"seed", slot.sparams.seed},
+            {"model", params_base.model_alias},
+            {"seed", slot.params.sampling.seed},
             {"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
-            {"temperature", slot.sparams.temp},
-            {"dynatemp_range", slot.sparams.dynatemp_range},
-            {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
-            {"top_k", slot.sparams.top_k},
-            {"top_p", slot.sparams.top_p},
-            {"min_p", slot.sparams.min_p},
-            {"xtc_probability", slot.sparams.xtc_probability},
-            {"xtc_threshold", slot.sparams.xtc_threshold},
-            {"typical_p", slot.sparams.typ_p},
-            {"repeat_last_n", slot.sparams.penalty_last_n},
-            {"repeat_penalty", slot.sparams.penalty_repeat},
-            {"presence_penalty", slot.sparams.penalty_present},
-            {"frequency_penalty", slot.sparams.penalty_freq},
-            {"dry_multiplier", slot.sparams.dry_multiplier},
-            {"dry_base", slot.sparams.dry_base},
-            {"dry_allowed_length", slot.sparams.dry_allowed_length},
-            {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
-            {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
-            {"mirostat", slot.sparams.mirostat},
-            {"mirostat_tau", slot.sparams.mirostat_tau},
-            {"mirostat_eta", slot.sparams.mirostat_eta},
-            {"penalize_nl", slot.sparams.penalize_nl},
+            {"temperature", slot.params.sampling.temp},
+            {"dynatemp_range", slot.params.sampling.dynatemp_range},
+            {"dynatemp_exponent", slot.params.sampling.dynatemp_exponent},
+            {"top_k", slot.params.sampling.top_k},
+            {"top_p", slot.params.sampling.top_p},
+            {"min_p", slot.params.sampling.min_p},
+            {"xtc_probability", slot.params.sampling.xtc_probability},
+            {"xtc_threshold", slot.params.sampling.xtc_threshold},
+            {"typical_p", slot.params.sampling.typ_p},
+            {"repeat_last_n", slot.params.sampling.penalty_last_n},
+            {"repeat_penalty", slot.params.sampling.penalty_repeat},
+            {"presence_penalty", slot.params.sampling.penalty_present},
+            {"frequency_penalty", slot.params.sampling.penalty_freq},
+            {"dry_multiplier", slot.params.sampling.dry_multiplier},
+            {"dry_base", slot.params.sampling.dry_base},
+            {"dry_allowed_length", slot.params.sampling.dry_allowed_length},
+            {"dry_penalty_last_n", slot.params.sampling.dry_penalty_last_n},
+            {"dry_sequence_breakers", slot.params.sampling.dry_sequence_breakers},
+            {"mirostat", slot.params.sampling.mirostat},
+            {"mirostat_tau", slot.params.sampling.mirostat_tau},
+            {"mirostat_eta", slot.params.sampling.mirostat_eta},
+            {"penalize_nl", slot.params.sampling.penalize_nl},
             {"stop", slot.params.antiprompt},
             {"max_tokens", slot.params.n_predict}, // User configured n_predict
             {"n_keep", slot.params.n_keep},
             {"n_discard", slot.params.n_discard},
-            {"ignore_eos", slot.sparams.ignore_eos},
+            {"ignore_eos", slot.params.sampling.ignore_eos},
             {"stream", slot.params.stream},
-            //{"logit_bias", slot.sparams.logit_bias},
-            {"n_probs", slot.sparams.n_probs},
-            {"min_keep", slot.sparams.min_keep},
-            {"grammar", slot.sparams.grammar},
+            //{"logit_bias", slot.params.sampling.logit_bias},
+            {"n_probs", slot.params.sampling.n_probs},
+            {"min_keep", slot.params.sampling.min_keep},
+            {"grammar", slot.params.sampling.grammar},
             {"samplers", samplers},
+            {"speculative", slot.can_speculate()},
+            {"speculative.n_max", slot.params.speculative.n_max},
+            {"speculative.n_min", slot.params.speculative.n_min},
+            {"speculative.p_min", slot.params.speculative.p_min},
         };
     }
 
@@ -1216,7 +1314,7 @@ struct server_context {
             {"index", slot.index},
         };
 
-        if (slot.sparams.n_probs > 0) {
+        if (slot.params.sampling.n_probs > 0) {
             const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
             const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
             const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
@@ -1249,7 +1347,7 @@ struct server_context {
             {"content", !slot.params.stream ? slot.generated_text : ""},
             {"id_slot", slot.id},
             {"stop", true},
-            {"model", params.model_alias},
+            {"model", params_base.model_alias},
             {"tokens_predicted", slot.n_decoded},
             {"tokens_evaluated", slot.n_prompt_tokens},
             {"generation_settings", get_formated_generation(slot)},
@@ -1265,7 +1363,7 @@ struct server_context {
             {"index", slot.index},
         };
 
-        if (slot.sparams.n_probs > 0) {
+        if (slot.params.sampling.n_probs > 0) {
             std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
                 const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
@@ -1422,10 +1520,10 @@ struct server_context {
                    data.at("input_prefix"),
                    data.at("input_suffix"),
                    data.at("input_extra"),
-                   params.n_batch,
-                   params.n_predict,
+                   params_base.n_batch,
+                   params_base.n_predict,
                    slots[0].n_ctx, // TODO: there should be a better way
-                   params.spm_infill,
+                   params_base.spm_infill,
                    tokenized_prompts[i]
                );
                create_task(data, tokens);
@@ -1798,7 +1896,7 @@ struct server_context {
        // TODO: simplify and improve
        for (server_slot & slot : slots) {
            if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
-                if (!params.ctx_shift) {
+                if (!params_base.ctx_shift) {
                    // this check is redundant (for good)
                    // we should never get here, because generation should already stopped in process_token()
                    slot.release();
@@ -1864,7 +1962,7 @@ struct server_context {
        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
 
        // next, batch any pending prompts without exceeding n_batch
-        if (params.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || batch.n_tokens == 0) {
            for (auto & slot : slots) {
                // this slot still has a prompt to be processed
                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
@@ -1917,7 +2015,7 @@ struct server_context {
                        continue;
                    }
                } else {
-                    if (!params.ctx_shift) {
+                    if (!params_base.ctx_shift) {
                        // if context shift is disabled, we make sure prompt size is smaller than KV size
                        // TODO: there should be a separate parameter that control prompt truncation
                        // context shift should be applied only during the generation phase
@@ -1963,11 +2061,11 @@ struct server_context {
                        slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
 
                        // reuse chunks from the cached prompt by shifting their KV cache in the new position
-                        if (params.n_cache_reuse > 0) {
+                        if (params_base.n_cache_reuse > 0) {
                            size_t head_c = slot.n_past; // cache
                            size_t head_p = slot.n_past; // current prompt
 
-                            SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+                            SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
 
                            while (head_c < slot.cache_tokens.size() &&
                                   head_p < prompt_tokens.size()) {
@@ -1980,7 +2078,7 @@ struct server_context {
                                    n_match++;
                                }
 
-                                if (n_match >= (size_t) params.n_cache_reuse) {
+                                if (n_match >= (size_t) params_base.n_cache_reuse) {
                                    SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                                    //for (size_t i = head_p; i < head_p + n_match; i++) {
                                    //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
@@ -2168,38 +2266,99 @@ struct server_context {
                continue; // continue loop of slots
            }
 
-            completion_token_output result;
-            const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+            llama_token id;
 
-            common_sampler_accept(slot.smpl, id, true);
+            {
+                completion_token_output result;
 
-            slot.n_decoded += 1;
-            if (slot.n_decoded == 1) {
-                slot.t_start_generation = ggml_time_us();
-                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
-                metrics.on_prompt_eval(slot);
-            }
+                id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
-            result.tok = id;
+                slot.i_batch = -1;
 
-            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+                common_sampler_accept(slot.smpl, id, true);
 
-            for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
-                result.probs.push_back({
-                    cur_p->data[i].id,
-                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                });
-            }
+                slot.n_decoded += 1;
+                if (slot.n_decoded == 1) {
+                    slot.t_start_generation = ggml_time_us();
+                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                    metrics.on_prompt_eval(slot);
+                }
 
-            if (!process_token(result, slot)) {
-                // release slot because of stop condition
-                slot.release();
-                slot.print_timings();
-                send_final_response(slot);
-                metrics.on_prediction(slot);
-                continue;
-            }
+                result.tok = id;
 
-            slot.i_batch = -1;
+                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+
+                for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
+                    result.probs.push_back({
+                        cur_p->data[i].id,
+                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                    });
+                }
+
+                if (!process_token(result, slot)) {
+                    // release slot because of stop condition
+                    slot.release();
+                    slot.print_timings();
+                    send_final_response(slot);
+                    metrics.on_prediction(slot);
+                    continue;
+                }
+            }
+
+            // check if the slot supports speculative decoding
+            if (!slot.can_speculate()) {
+                continue;
+            }
+
+            struct common_speculative_params params_spec;
+            params_spec.n_draft = slot.params.speculative.n_max;
+            params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+            params_spec.p_min = slot.params.speculative.p_min;
+
+            llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+            // ignore small drafts
+            if (slot.params.speculative.n_min > (int) draft.size()) {
+                continue;
+            }
+
+            // construct the speculation batch
+            common_batch_clear(slot.batch_spec);
+            common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+            }
+
+            llama_decode(ctx, slot.batch_spec);
+
+            // the accepted tokens from the speculation
+            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+            slot.n_past += ids.size();
+            slot.n_decoded += ids.size();
+
+            slot.cache_tokens.push_back(id);
+            slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+            llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+            for (size_t i = 0; i < ids.size(); ++i) {
+                completion_token_output result;
+
+                result.tok = ids[i];
+
+                if (!process_token(result, slot)) {
+                    // release slot because of stop condition
+                    slot.release();
+                    slot.print_timings();
+                    send_final_response(slot);
+                    metrics.on_prediction(slot);
+                    break;
+                }
+            }
+
+            SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
        }
    }
 
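A worked trace of the accounting above (illustrative, with made-up tokens): suppose the draft proposed 5 tokens and the target model agreed with the first 3.

    // draft = [d0, d1, d2, d3, d4]        5 proposed tokens
    // ids   = [d0, d1, d2, t]             3 accepted drafts + 1 token the target
    //                                     sampled at the point of divergence
    // SRV_DBG prints ids.size() - 1 = 3   -> "accepted 3/5 draft tokens"
    //
    // cache_tokens receives `id` (the token that seeded the batch) plus all of `ids`
    // except its last element: that final token is handed to process_token() and is
    // decoded normally on the next iteration, while llama_kv_cache_seq_rm() drops
    // the KV entries of the rejected draft tail.

One target decode therefore yields ids.size() tokens of output instead of one, which is where the speedup comes from; in the worst case (no draft token accepted) ids still contains one token, so progress is never lost.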
@@ -2697,7 +2856,7 @@ int main(int argc, char ** argv) {
    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
        json data = {
            { "default_generation_settings", ctx_server.default_generation_settings_for_props },
-            { "total_slots",             ctx_server.params.n_parallel },
+            { "total_slots",             ctx_server.params_base.n_parallel },
            { "chat_template",           llama_get_chat_template(ctx_server.model) },
        };
 
@@ -2705,7 +2864,7 @@ int main(int argc, char ** argv) {
    };
 
    const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params.endpoint_props) {
+        if (!ctx_server.params_base.endpoint_props) {
            res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
@@ -2718,7 +2877,7 @@ int main(int argc, char ** argv) {
    };
 
    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
+        if (ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
@ -2824,7 +2983,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// TODO: maybe merge this function with "handle_completions_generic"
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
|
||||
if (ctx_server.params.embedding) {
|
||||
if (ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
@@ -3001,7 +3160,7 @@ int main(int argc, char ** argv) {
    };
 
    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params.reranking || ctx_server.params.embedding) {
+        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }