mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-23 21:17:54 +01:00
speculative : implement stochastic speculative sampling (#5625)
* (WIP) Implement stochastic speculative decoding * sample from residual distribution on draft accept failure * fix #5657: force greedy sampling with probs when temp is 0 * remove p_accept parameter * fix style * remove unused variables * add srand() in speculative.cpp * replace use of rand() with mt19937 sampling * fixes based on review (@JohannesGaessler) * fix r random generation * randomly select next sequence to verify + fix bug in memory freeing * fix bug in active_seqs sync * fix uniform int distribution initialization * remove warnings from comparison between int and size_t * check grammar in `llama_sample_probability_distribution_impl` * remove malloc code by utilizing vectors * add PR link to README
This commit is contained in:
parent
4ffcdce2ff
commit
6d341ab6c5
@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
break;
|
||||
}
|
||||
params.n_sequences = std::stoi(argv[i]);
|
||||
} else if (arg == "--p-accept" || arg == "-pa") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.p_accept = std::stof(argv[i]);
|
||||
} else if (arg == "--p-split" || arg == "-ps") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
||||
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
|
||||
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
|
||||
printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
|
||||
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
|
||||
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
||||
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
|
||||
|
@ -53,11 +53,10 @@ struct gpt_params {
|
||||
int32_t n_ctx = 512; // context size
|
||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
|
||||
int32_t n_draft = 5; // number of tokens to draft during speculative decoding
|
||||
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
||||
int32_t n_parallel = 1; // number of parallel sequences to decode
|
||||
int32_t n_sequences = 1; // number of sequences to decode
|
||||
float p_accept = 0.5f; // speculative decoding accept probability
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
|
@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
|
||||
return id;
|
||||
}
|
||||
|
||||
static llama_token_data_array llama_sample_probability_distribution_impl(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
const int idx) {
|
||||
const llama_sampling_params & params = ctx_sampling->params;
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
|
||||
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
||||
const float penalty_repeat = params.penalty_repeat;
|
||||
const float penalty_freq = params.penalty_freq;
|
||||
const float penalty_present = params.penalty_present;
|
||||
const bool penalize_nl = params.penalize_nl;
|
||||
|
||||
auto & prev = ctx_sampling->prev;
|
||||
auto & cur = ctx_sampling->cur;
|
||||
|
||||
// Get a pointer to the logits
|
||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||
|
||||
// Declare original_logits at the beginning of the function scope
|
||||
std::vector<float> original_logits;
|
||||
|
||||
// apply params.logit_bias map
|
||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||
logits[it->first] += it->second;
|
||||
}
|
||||
|
||||
if (ctx_cfg) {
|
||||
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
||||
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
||||
}
|
||||
|
||||
cur.clear();
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||
|
||||
// apply penalties
|
||||
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
||||
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||
if (penalty_tokens_used_size) {
|
||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||
|
||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
||||
|
||||
if (!penalize_nl) {
|
||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
||||
cur_p.data[idx].logit = nl_logit;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply grammar checks
|
||||
if (ctx_sampling->grammar != NULL) {
|
||||
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
||||
}
|
||||
|
||||
llama_sample_softmax(ctx_main, &cur_p);
|
||||
return cur_p;
|
||||
}
|
||||
|
||||
llama_token llama_sampling_sample(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
|
||||
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
||||
}
|
||||
|
||||
llama_token_data_array llama_sampling_probability_distribution(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
const int idx) {
|
||||
return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
|
||||
}
|
||||
|
||||
void llama_sampling_accept(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
|
@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
|
||||
struct llama_context * ctx_cfg,
|
||||
int idx = 0);
|
||||
|
||||
// returns the probability that token of given id will be sampled
|
||||
llama_token_data_array llama_sampling_probability_distribution(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
int idx = 0);
|
||||
|
||||
void llama_sampling_accept(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
|
@ -6,3 +6,4 @@ More info:
|
||||
|
||||
- https://github.com/ggerganov/llama.cpp/pull/2926
|
||||
- https://github.com/ggerganov/llama.cpp/pull/3624
|
||||
- https://github.com/ggerganov/llama.cpp/pull/5625
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
@ -18,6 +19,7 @@ struct seq_draft {
|
||||
std::vector<int> i_batch_tgt;
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
std::vector<std::vector<llama_token_data>> dists;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling;
|
||||
};
|
||||
@ -37,12 +39,15 @@ int main(int argc, char ** argv) {
|
||||
// max number of parallel drafting sequences (i.e. tree branches)
|
||||
const int n_seq_dft = params.n_parallel;
|
||||
|
||||
// probability threshold for accepting a token from the draft model
|
||||
const float p_accept = params.p_accept;
|
||||
|
||||
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
||||
const float p_split = params.p_split;
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
std::default_random_engine rng(params.seed);
|
||||
std::uniform_real_distribution<> u_dist;
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_set_target(log_filename_generator("speculative", "log"));
|
||||
LOG_TEE("Log start\n");
|
||||
@ -166,7 +171,9 @@ int main(int argc, char ** argv) {
|
||||
std::vector<seq_draft> drafts(n_seq_dft);
|
||||
|
||||
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
||||
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
|
||||
if (params.sparams.temp == 0) {
|
||||
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
|
||||
}
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
||||
@ -182,12 +189,15 @@ int main(int argc, char ** argv) {
|
||||
drafts[0].i_batch_tgt[0] = 0;
|
||||
|
||||
while (true) {
|
||||
std::set<int> active_seqs = {};
|
||||
|
||||
// print current draft sequences
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
if (!drafts[s].active) {
|
||||
continue;
|
||||
}
|
||||
|
||||
active_seqs.insert(s);
|
||||
const auto & tokens = drafts[s].tokens;
|
||||
|
||||
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
|
||||
@ -196,48 +206,156 @@ int main(int argc, char ** argv) {
|
||||
int i_dft = 0;
|
||||
int s_keep = 0;
|
||||
|
||||
llama_token token_id;
|
||||
std::string token_str;
|
||||
|
||||
// loop until we fail to accept a drafted token or we run out of drafted tokens
|
||||
while (true) {
|
||||
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
|
||||
// sample from the target model
|
||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
|
||||
|
||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
||||
|
||||
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
|
||||
|
||||
if (!params.use_color) {
|
||||
printf("%s", token_str.c_str());
|
||||
}
|
||||
|
||||
if (id == llama_token_eos(model_tgt)) {
|
||||
has_eos = true;
|
||||
}
|
||||
|
||||
++n_predict;
|
||||
|
||||
// check if the target token matches any of the drafts
|
||||
// for stochastic sampling, attempt to match the token with the drafted tokens
|
||||
{
|
||||
bool matches = false;
|
||||
bool accept = false;
|
||||
if (params.sparams.temp > 0) {
|
||||
// stochastic verification
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
if (!drafts[s].active) {
|
||||
continue;
|
||||
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
float p_tgt = 0, p_dft = 0;
|
||||
|
||||
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
||||
|
||||
while (active_seqs.size() > 0) {
|
||||
// randomly select a sequence to verify from active sequences
|
||||
std::uniform_int_distribution<u_int> u_int_dist(0, active_seqs.size() - 1);
|
||||
int s = *std::next(active_seqs.begin(), u_int_dist(rng));
|
||||
if (i_dft >= (int) drafts[s].tokens.size()) {
|
||||
drafts[s].active = false;
|
||||
active_seqs.erase(s);
|
||||
continue;
|
||||
}
|
||||
if (accept) {
|
||||
// if we already accepted a token, we can skip the rest
|
||||
if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
|
||||
drafts[s].active = false;
|
||||
active_seqs.erase(s);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
|
||||
float r = u_dist(rng);
|
||||
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
|
||||
// acquire the token probabilities assigned by the draft and target models
|
||||
for (size_t i = 0; i < dist_tgt.size; i++) {
|
||||
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
|
||||
p_tgt = dist_tgt.data[i].p;
|
||||
}
|
||||
if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
|
||||
p_dft = dist_dft.data[i].p;
|
||||
}
|
||||
if (p_tgt && p_dft) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
|
||||
if (r <= p_tgt / p_dft) {
|
||||
s_keep = s;
|
||||
accept = true;
|
||||
token_id = drafts[s].tokens[i_dft];
|
||||
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||
|
||||
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
|
||||
break;
|
||||
} else {
|
||||
LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
|
||||
drafts[s].active = false;
|
||||
|
||||
// calculate residual probability
|
||||
GGML_ASSERT(dist_tgt.sorted);
|
||||
GGML_ASSERT(dist_dft.sorted);
|
||||
float sum_probs = 0.0f;
|
||||
|
||||
// sort dist by id
|
||||
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||
return a.id < b.id;
|
||||
});
|
||||
std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||
return a.id < b.id;
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < dist_tgt.size; i++) {
|
||||
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
|
||||
sum_probs += dist_tgt.data[i].p;
|
||||
}
|
||||
for (size_t i = 0; i < dist_tgt.size; i++) {
|
||||
dist_tgt.data[i].p /= sum_probs;
|
||||
}
|
||||
|
||||
// sort dist_tgt by p desc
|
||||
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||
return a.p > b.p;
|
||||
});
|
||||
}
|
||||
|
||||
active_seqs.erase(s);
|
||||
for(int i = 0; i < n_seq_dft; i++) {
|
||||
if (i == s) {
|
||||
continue;
|
||||
}
|
||||
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
||||
// synchronize active status for sequences with the same drafted token
|
||||
drafts[i].active = drafts[i].active && accept;
|
||||
if (!drafts[i].active) {
|
||||
active_seqs.erase(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
|
||||
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
|
||||
if (!accept) {
|
||||
// all drafted tokens were rejected
|
||||
// sample from the target model
|
||||
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
|
||||
token_id = llama_sample_token(ctx_tgt, &dist_tgt);
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||
}
|
||||
|
||||
s_keep = s;
|
||||
matches = true;
|
||||
} else {
|
||||
drafts[s].active = false;
|
||||
} else {
|
||||
// greedy verification
|
||||
|
||||
// sample from the target model
|
||||
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||
|
||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
||||
|
||||
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
if (!drafts[s].active) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
|
||||
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
|
||||
|
||||
s_keep = s;
|
||||
accept = true;
|
||||
} else {
|
||||
drafts[s].active = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (matches) {
|
||||
if (token_id == llama_token_eos(model_tgt)) {
|
||||
has_eos = true;
|
||||
}
|
||||
++n_predict;
|
||||
|
||||
if (accept) {
|
||||
++n_accept;
|
||||
++n_past_tgt;
|
||||
++n_past_dft;
|
||||
@ -245,17 +363,21 @@ int main(int argc, char ** argv) {
|
||||
if (params.use_color) {
|
||||
// Color token according to its origin sequence
|
||||
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
|
||||
fflush(stdout);
|
||||
} else {
|
||||
printf("%s", token_str.c_str());
|
||||
}
|
||||
fflush(stdout);
|
||||
continue;
|
||||
} else {
|
||||
printf("%s", token_str.c_str());
|
||||
fflush(stdout);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (params.use_color) {
|
||||
printf("%s", token_str.c_str());
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
|
||||
{
|
||||
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
|
||||
|
||||
// TODO: simplify
|
||||
{
|
||||
@ -275,21 +397,21 @@ int main(int argc, char ** argv) {
|
||||
drafts[s].active = false;
|
||||
drafts[s].tokens.clear();
|
||||
drafts[s].i_batch_tgt.clear();
|
||||
drafts[s].dists.clear();
|
||||
}
|
||||
// note: will be erased after the speculation phase
|
||||
drafts[0].tokens.push_back(id);
|
||||
drafts[0].tokens.push_back(token_id);
|
||||
drafts[0].dists.push_back(std::vector<llama_token_data>());
|
||||
drafts[0].i_batch_tgt.push_back(0);
|
||||
|
||||
llama_batch_clear(batch_dft);
|
||||
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
|
||||
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||
llama_decode (ctx_dft, batch_dft);
|
||||
llama_decode(ctx_dft, batch_dft);
|
||||
|
||||
++n_past_dft;
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
if (n_predict > params.n_predict || has_eos) {
|
||||
@ -334,12 +456,6 @@ int main(int argc, char ** argv) {
|
||||
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
||||
}
|
||||
|
||||
if (cur_p[0].p < p_accept) {
|
||||
LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
|
||||
drafts[s].drafting = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<int> sa(1, s);
|
||||
|
||||
// attempt to split the branch if the probability is high enough
|
||||
@ -367,6 +483,7 @@ int main(int argc, char ** argv) {
|
||||
drafts[n_seq_cur].skip = true;
|
||||
|
||||
drafts[n_seq_cur].tokens = drafts[s].tokens;
|
||||
drafts[n_seq_cur].dists = drafts[s].dists;
|
||||
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
||||
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
||||
|
||||
@ -389,6 +506,8 @@ int main(int argc, char ** argv) {
|
||||
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
||||
|
||||
drafts[s].tokens.push_back(id);
|
||||
// save cur_p.data into drafts[s].dists
|
||||
drafts[s].dists.push_back(cur_p);
|
||||
|
||||
// add unique drafted tokens to the target batch
|
||||
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||
@ -440,6 +559,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
drafts[s].tokens.erase(drafts[s].tokens.begin());
|
||||
drafts[s].dists.erase(drafts[s].dists.begin());
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user