mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 20:22:25 +01:00
speculative : add infill mode
ggml-ci
This commit is contained in:
parent
0eb4e12bee
commit
b83cae088c
@ -11,7 +11,9 @@
|
||||
|
||||
struct common_speculative {
|
||||
struct llama_context * ctx;
|
||||
|
||||
struct common_sampler * smpl;
|
||||
struct common_sampler * smpl_infill;
|
||||
|
||||
llama_batch batch;
|
||||
llama_tokens prompt;
|
||||
@ -20,14 +22,26 @@ struct common_speculative {
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_dft) {
|
||||
auto * result = new common_speculative {
|
||||
/* .ctx = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .prompt = */ {},
|
||||
/* .ctx = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .smpl_infill = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .prompt = */ {},
|
||||
};
|
||||
|
||||
// TODO: optimize or pass from outside?
|
||||
#if 0
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
|
||||
params.top_k = 10;
|
||||
|
||||
params.samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
@ -41,28 +55,15 @@ struct common_speculative * common_speculative_init(
|
||||
COMMON_SAMPLER_TYPE_INFILL,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
result->smpl_infill = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
#else
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
|
||||
params.top_k = 10;
|
||||
|
||||
params.samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
#endif
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec) {
|
||||
common_sampler_free(spec->smpl);
|
||||
common_sampler_free(spec->smpl_infill);
|
||||
|
||||
llama_batch_free(spec->batch);
|
||||
|
||||
@ -133,7 +134,7 @@ llama_tokens common_speculative_gen_draft(
|
||||
llama_token id_last) {
|
||||
auto & batch = spec->batch;
|
||||
auto & ctx = spec->ctx;
|
||||
auto & smpl = spec->smpl;
|
||||
auto & smpl = params.infill ? spec->smpl_infill : spec->smpl;
|
||||
auto & prompt = spec->prompt;
|
||||
|
||||
int reuse_i = 0;
|
||||
|
@ -10,6 +10,8 @@ struct common_speculative_params {
|
||||
int n_reuse = 256;
|
||||
|
||||
float p_min = 0.9f; // min probabiliy required to accept a token in the draft
|
||||
|
||||
bool infill = false; // use infill sampling (useful for FIM)
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
|
||||
|
@ -2315,6 +2315,7 @@ struct server_context {
|
||||
params_spec.n_draft = slot.params.speculative.n_max;
|
||||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
|
||||
params_spec.p_min = slot.params.speculative.p_min;
|
||||
params_spec.infill = slot.inf_type == SERVER_TASK_INF_TYPE_INFILL;
|
||||
|
||||
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user