speculative : add infill mode

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-26 11:14:17 +02:00
parent 0eb4e12bee
commit b83cae088c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 26 additions and 22 deletions

View File

@ -11,7 +11,9 @@
struct common_speculative {
struct llama_context * ctx;
struct common_sampler * smpl;
struct common_sampler * smpl_infill;
llama_batch batch;
llama_tokens prompt;
@ -22,12 +24,24 @@ struct common_speculative * common_speculative_init(
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .smpl_infill = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt = */ {},
};
// TODO: optimize or pass from outside?
#if 0
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
{
common_params_sampling params;
params.no_perf = false;
@ -41,28 +55,15 @@ struct common_speculative * common_speculative_init(
COMMON_SAMPLER_TYPE_INFILL,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
result->smpl_infill = common_sampler_init(llama_get_model(ctx_dft), params);
}
#else
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
#endif
return result;
}
void common_speculative_free(struct common_speculative * spec) {
common_sampler_free(spec->smpl);
common_sampler_free(spec->smpl_infill);
llama_batch_free(spec->batch);
@ -133,7 +134,7 @@ llama_tokens common_speculative_gen_draft(
llama_token id_last) {
auto & batch = spec->batch;
auto & ctx = spec->ctx;
auto & smpl = spec->smpl;
auto & smpl = params.infill ? spec->smpl_infill : spec->smpl;
auto & prompt = spec->prompt;
int reuse_i = 0;

View File

@ -10,6 +10,8 @@ struct common_speculative_params {
int n_reuse = 256;
float p_min = 0.9f; // min probabiliy required to accept a token in the draft
bool infill = false; // use infill sampling (useful for FIM)
};
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);

View File

@ -2315,6 +2315,7 @@ struct server_context {
params_spec.n_draft = slot.params.speculative.n_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
params_spec.infill = slot.inf_type == SERVER_TASK_INF_TYPE_INFILL;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);