From 6d341ab6c53cd51f2921d986d0090cc8b049b39a Mon Sep 17 00:00:00 2001 From: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Date: Tue, 5 Mar 2024 03:24:00 +0900 Subject: [PATCH] speculative : implement stochastic speculative sampling (#5625) * (WIP) Implement stochastic speculative decoding * sample from residual distribution on draft accept failure * fix #5657: force greedy sampling with probs when temp is 0 * remove p_accept parameter * fix style * remove unused variables * add srand() in speculative.cpp * replace use of rand() with mt19937 sampling * fixes based on review (@JohannesGaessler) * fix r random generation * randomly select next sequence to verify + fix bug in memory freeing * fix bug in active_seqs sync * fix uniform int distribution initialization * remove warnings from comparison between int and size_t * check grammar in `llama_sample_probability_distribution_impl` * remove malloc code by utilizing vectors * add PR link to README --- common/common.cpp | 7 - common/common.h | 3 +- common/sampling.cpp | 79 ++++++++++ common/sampling.h | 7 + examples/speculative/README.md | 1 + examples/speculative/speculative.cpp | 224 ++++++++++++++++++++------- 6 files changed, 260 insertions(+), 61 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index dbe7e9229..036a98134 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.n_sequences = std::stoi(argv[i]); - } else if (arg == "--p-accept" || arg == "-pa") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.p_accept = std::stof(argv[i]); } else if (arg == "--p-split" || arg == "-ps") { if (++i >= argc) { invalid_param = true; @@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); - printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept); printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); diff --git a/common/common.h b/common/common.h index b2868833b..977ce419f 100644 --- a/common/common.h +++ b/common/common.h @@ -53,11 +53,10 @@ struct gpt_params { int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 8; // number of tokens to draft during speculative decoding + int32_t n_draft = 5; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_parallel = 1; // number of parallel sequences to decode int32_t n_sequences = 1; // number of sequences to decode - float p_accept = 0.5f; // speculative decoding accept probability float p_split = 0.1f; // speculative decoding split probability int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) diff --git a/common/sampling.cpp b/common/sampling.cpp index e67096bea..823031feb 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl( return id; } +static llama_token_data_array llama_sample_probability_distribution_impl( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + const llama_sampling_params & params = ctx_sampling->params; + + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + + const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; + const float penalty_repeat = params.penalty_repeat; + const float penalty_freq = params.penalty_freq; + const float penalty_present = params.penalty_present; + const bool penalize_nl = params.penalize_nl; + + auto & prev = ctx_sampling->prev; + auto & cur = ctx_sampling->cur; + + // Get a pointer to the logits + float * logits = llama_get_logits_ith(ctx_main, idx); + + // Declare original_logits at the beginning of the function scope + std::vector original_logits; + + // apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + logits[it->first] += it->second; + } + + if (ctx_cfg) { + float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); + llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); + } + + cur.clear(); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + + // apply penalties + const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; + const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); + if (penalty_tokens_used_size) { + const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; + + llama_sample_repetition_penalties(ctx_main, &cur_p, + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); + + if (!penalize_nl) { + for (size_t idx = 0; idx < cur_p.size; idx++) { + if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { + cur_p.data[idx].logit = nl_logit; + break; + } + } + } + } + + // apply grammar checks + if (ctx_sampling->grammar != NULL) { + llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); + } + + llama_sample_softmax(ctx_main, &cur_p); + return cur_p; +} + llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, @@ -304,6 +375,14 @@ llama_token llama_sampling_sample( return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false); } +llama_token_data_array llama_sampling_probability_distribution( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx); +} + void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, diff --git a/common/sampling.h b/common/sampling.h index 95d875394..48b2459d1 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -131,6 +131,13 @@ llama_token llama_sampling_sample( struct llama_context * ctx_cfg, int idx = 0); +// returns the probability that token of given id will be sampled +llama_token_data_array llama_sampling_probability_distribution( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx = 0); + void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, diff --git a/examples/speculative/README.md b/examples/speculative/README.md index 814efa592..a6608c5fe 100644 --- a/examples/speculative/README.md +++ b/examples/speculative/README.md @@ -6,3 +6,4 @@ More info: - https://github.com/ggerganov/llama.cpp/pull/2926 - https://github.com/ggerganov/llama.cpp/pull/3624 +- https://github.com/ggerganov/llama.cpp/pull/5625 diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 3848791d4..85bc0a762 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -18,6 +19,7 @@ struct seq_draft { std::vector i_batch_tgt; std::vector tokens; + std::vector> dists; struct llama_sampling_context * ctx_sampling; }; @@ -37,12 +39,15 @@ int main(int argc, char ** argv) { // max number of parallel drafting sequences (i.e. tree branches) const int n_seq_dft = params.n_parallel; - // probability threshold for accepting a token from the draft model - const float p_accept = params.p_accept; - // probability threshold for splitting a draft branch (only for n_seq_dft > 1) const float p_split = params.p_split; + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + std::default_random_engine rng(params.seed); + std::uniform_real_distribution<> u_dist; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); LOG_TEE("Log start\n"); @@ -166,7 +171,9 @@ int main(int argc, char ** argv) { std::vector drafts(n_seq_dft); params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar - params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model + if (params.sparams.temp == 0) { + params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model + } for (int s = 0; s < n_seq_dft; ++s) { drafts[s].ctx_sampling = llama_sampling_init(params.sparams); @@ -182,12 +189,15 @@ int main(int argc, char ** argv) { drafts[0].i_batch_tgt[0] = 0; while (true) { + std::set active_seqs = {}; + // print current draft sequences for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { continue; } + active_seqs.insert(s); const auto & tokens = drafts[s].tokens; LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); @@ -196,48 +206,156 @@ int main(int argc, char ** argv) { int i_dft = 0; int s_keep = 0; + llama_token token_id; + std::string token_str; + + // loop until we fail to accept a drafted token or we run out of drafted tokens while (true) { - LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - - // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); - - llama_sampling_accept(ctx_sampling, ctx_tgt, id, true); - - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); - - const std::string token_str = llama_token_to_piece(ctx_tgt, id); - - if (!params.use_color) { - printf("%s", token_str.c_str()); - } - - if (id == llama_token_eos(model_tgt)) { - has_eos = true; - } - - ++n_predict; // check if the target token matches any of the drafts + // for stochastic sampling, attempt to match the token with the drafted tokens { - bool matches = false; + bool accept = false; + if (params.sparams.temp > 0) { + // stochastic verification - for (int s = 0; s < n_seq_dft; ++s) { - if (!drafts[s].active) { - continue; + llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + float p_tgt = 0, p_dft = 0; + + // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); + + while (active_seqs.size() > 0) { + // randomly select a sequence to verify from active sequences + std::uniform_int_distribution u_int_dist(0, active_seqs.size() - 1); + int s = *std::next(active_seqs.begin(), u_int_dist(rng)); + if (i_dft >= (int) drafts[s].tokens.size()) { + drafts[s].active = false; + active_seqs.erase(s); + continue; + } + if (accept) { + // if we already accepted a token, we can skip the rest + if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) { + drafts[s].active = false; + active_seqs.erase(s); + } + continue; + } + LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); + float r = u_dist(rng); + llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true }; + // acquire the token probabilities assigned by the draft and target models + for (size_t i = 0; i < dist_tgt.size; i++) { + if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { + p_tgt = dist_tgt.data[i].p; + } + if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) { + p_dft = dist_dft.data[i].p; + } + if (p_tgt && p_dft) { + break; + } + } + LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt); + if (r <= p_tgt / p_dft) { + s_keep = s; + accept = true; + token_id = drafts[s].tokens[i_dft]; + token_str = llama_token_to_piece(ctx_tgt, token_id); + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + + LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); + break; + } else { + LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str()); + drafts[s].active = false; + + // calculate residual probability + GGML_ASSERT(dist_tgt.sorted); + GGML_ASSERT(dist_dft.sorted); + float sum_probs = 0.0f; + + // sort dist by id + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + + for (size_t i = 0; i < dist_tgt.size; i++) { + dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p); + sum_probs += dist_tgt.data[i].p; + } + for (size_t i = 0; i < dist_tgt.size; i++) { + dist_tgt.data[i].p /= sum_probs; + } + + // sort dist_tgt by p desc + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.p > b.p; + }); + } + + active_seqs.erase(s); + for(int i = 0; i < n_seq_dft; i++) { + if (i == s) { + continue; + } + if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { + // synchronize active status for sequences with the same drafted token + drafts[i].active = drafts[i].active && accept; + if (!drafts[i].active) { + active_seqs.erase(s); + } + } + } } - if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) { - LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str()); + if (!accept) { + // all drafted tokens were rejected + // sample from the target model + LOG("all drafted tokens were rejected, sampling from residual distribution\n"); + token_id = llama_sample_token(ctx_tgt, &dist_tgt); + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + token_str = llama_token_to_piece(ctx_tgt, token_id); + } - s_keep = s; - matches = true; - } else { - drafts[s].active = false; + } else { + // greedy verification + + // sample from the target model + LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + + //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + + token_str = llama_token_to_piece(ctx_tgt, token_id); + + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { + continue; + } + + if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) { + LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str()); + + s_keep = s; + accept = true; + } else { + drafts[s].active = false; + } } } - if (matches) { + if (token_id == llama_token_eos(model_tgt)) { + has_eos = true; + } + ++n_predict; + + if (accept) { ++n_accept; ++n_past_tgt; ++n_past_dft; @@ -245,17 +363,21 @@ int main(int argc, char ** argv) { if (params.use_color) { // Color token according to its origin sequence printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str()); - fflush(stdout); + } else { + printf("%s", token_str.c_str()); } + fflush(stdout); continue; + } else { + printf("%s", token_str.c_str()); + fflush(stdout); + break; } } - if (params.use_color) { - printf("%s", token_str.c_str()); - } - fflush(stdout); + } - LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + { + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str()); // TODO: simplify { @@ -275,21 +397,21 @@ int main(int argc, char ** argv) { drafts[s].active = false; drafts[s].tokens.clear(); drafts[s].i_batch_tgt.clear(); + drafts[s].dists.clear(); } // note: will be erased after the speculation phase - drafts[0].tokens.push_back(id); + drafts[0].tokens.push_back(token_id); + drafts[0].dists.push_back(std::vector()); drafts[0].i_batch_tgt.push_back(0); llama_batch_clear(batch_dft); - llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true); + llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - llama_decode (ctx_dft, batch_dft); + llama_decode(ctx_dft, batch_dft); ++n_past_dft; - - break; } if (n_predict > params.n_predict || has_eos) { @@ -334,12 +456,6 @@ int main(int argc, char ** argv) { k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); } - if (cur_p[0].p < p_accept) { - LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept); - drafts[s].drafting = false; - continue; - } - std::vector sa(1, s); // attempt to split the branch if the probability is high enough @@ -367,6 +483,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].skip = true; drafts[n_seq_cur].tokens = drafts[s].tokens; + drafts[n_seq_cur].dists = drafts[s].dists; drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; @@ -389,6 +506,8 @@ int main(int argc, char ** argv) { llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); drafts[s].tokens.push_back(id); + // save cur_p.data into drafts[s].dists + drafts[s].dists.push_back(cur_p); // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); @@ -440,6 +559,7 @@ int main(int argc, char ** argv) { } drafts[s].tokens.erase(drafts[s].tokens.begin()); + drafts[s].dists.erase(drafts[s].dists.begin()); } }