// llama.cpp/examples/speculative/speculative.cpp
// (605 lines, 23 KiB, C++)

#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <vector>
// maximum allowed absolute difference between the target and draft vocabulary sizes;
// main() refuses to run when the difference is larger than this
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
// first token id at which main() starts comparing the token texts of the two vocabularies
// (ids below this value are not compared)
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
// State of a single draft sequence (one branch of the draft tree).
struct seq_draft {
    bool active   = false; // the sequence takes part in the current verification round
    bool drafting = false; // the sequence is still being extended by the draft model
    bool skip     = false; // set on a freshly split branch so it is not sampled again in the same drafting step

    int i_batch_dft = 0; // index inside batch_dft of the most recent token of this sequence

    std::vector<int> i_batch_tgt; // for every token of this sequence: its index inside batch_tgt

    std::vector<llama_token> tokens;                  // tokens of this sequence
    std::vector<std::vector<llama_token_data>> dists; // per token: the candidate list saved from the draft sampler

    // sampling context of this sequence; assigned in main() via llama_sampling_init
    // and released there via llama_sampling_free
    struct llama_sampling_context * ctx_sampling = nullptr;
};
int main(int argc, char ** argv) {
gpt_params params;
if (gpt_params_parse(argc, argv, params) == false) {
return 1;
}
if (params.model_draft.empty()) {
fprintf(stderr, "%s: error: --model-draft is required\n", __func__);
return 1;
}
// max number of parallel drafting sequences (i.e. tree branches)
const int n_seq_dft = params.n_parallel;
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_split = params.p_split;
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
std::default_random_engine rng(params.seed);
std::uniform_real_distribution<> u_dist;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log"));
LOG_TEE("Log start\n");
log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS
// init llama.cpp
ggml : add numa options (#5377) * Added numa options to allow finer grained control as well as plumbing for a new mirror mode that will require numa.h * Reverted Makefile * Fixed include * Removed sched.h from ggml.h, moved ggml_get_numa_affinity into ggml.c, removed trailing whitespace and fixed up a few inconsistent variables * removed trailing whitespace * Added numa options to allow finer grained control as well as plumbing for a new mirror mode that will require numa.h * Reverting Makefile * Fixed a number of issues with the move from BOOL to ggml_numa_strategies. Added a note about mirror mode note being implemented yet * Removing MIRROR_MODE code for this PR * Removing last bit of MIRROR_MODE code for this PR * Removing unneeded branch in server.cpp example and moving get_numa_affinity and making it static * Fixed lingering init_llama_backend() bool calls in tests and examples * Remote enum llama_numa_strategies * Revert bad merge with dynatemp flags * add missing enum ggml_numa_strategies declaration and revert sync problem with master * add missing enum ggml_numa_strategies declaration * fixed ggml_init_numa variable * Update ggml.h Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * Update READMEs with info about numa flags, change INTERLEAVE strategy name to DISTRIBUTE everywhere, implement the improved distribution strategy from @rankaiyx, fix a spelling mistake and un-merge some bad merges * split numa init out from llama_backend_init and created llama_numa_init. 
Updated all code paths and samples * Fix up some boolean vs enum comparisons * Added #ifdefs for non-Linux OS that don't have cpu_set_t datatype * Update ggml.h Align enum values Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml.c Remove whitespace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml.c align paremeters Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update examples/server/server.cpp remove whitespace and align brace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update common/common.cpp Remove whitespace and align brace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * unified ggml_numa_strategy enum and fixed text alignment in server.cpp example * Update ggml.c simplified return for platforms without NUMA support Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * removed redundant else from cli argument processing of --numa * whitespace --------- Co-authored-by: root <root@nenya.lothlorien.ca> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Jared Van Bortel <jared@nomic.ai>
2024-02-16 10:31:07 +01:00
llama_backend_init();
llama_numa_init(params.numa);
llama_model * model_tgt = NULL;
llama_model * model_dft = NULL;
llama_context * ctx_tgt = NULL;
llama_context * ctx_dft = NULL;
// load the target model
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
if (params.n_threads_draft > 0) {
params.n_threads = params.n_threads_draft;
}
params.n_threads_batch = params.n_threads_batch_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
llama_token_to_piece(ctx_tgt, i).c_str(),
llama_token_to_piece(ctx_dft, i).c_str());
return 1;
}
}
}
// Tokenize the prompt
const bool add_bos_tgt = llama_should_add_bos_token(model_tgt);
LOG("add_bos tgt: %d\n", add_bos_tgt);
const bool add_bos_dft = llama_should_add_bos_token(model_dft);
LOG("add_bos dft: %d\n", add_bos_dft);
if (add_bos_tgt != add_bos_dft) {
fprintf(stderr, "%s: error: draft model add_bos must match target model to use speculation but ", __func__);
fprintf(stderr, "add_bos_dft = %d while add_bos_tgt = %d\n", add_bos_dft, add_bos_tgt);
return 1;
}
std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx_tgt, params.prompt, add_bos_tgt, true);
const int max_context_size = llama_n_ctx(ctx_tgt);
const int max_tokens_list_size = max_context_size - 4;
if ((int) inp.size() > max_tokens_list_size) {
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
return 1;
}
fprintf(stderr, "\n\n");
for (auto id : inp) {
fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str());
}
fflush(stderr);
const int n_input = inp.size();
const auto t_enc_start = ggml_time_us();
// eval the prompt with both models
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
// how many tokens to draft each time
int n_draft = params.n_draft;
int n_predict = 0;
int n_drafted = 0;
int n_accept = 0;
int n_past_tgt = inp.size();
int n_past_dft = inp.size();
// used to determine end of generation
bool has_eos = false;
// target model sampling context
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
if (params.sparams.temp == 0) {
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
}
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
}
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
const auto t_dec_start = ggml_time_us();
// sample from the last token of the prompt
drafts[0].i_batch_tgt.resize(1);
drafts[0].i_batch_tgt[0] = 0;
while (true) {
std::set<int> active_seqs = {};
// print current draft sequences
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}
active_seqs.insert(s);
const auto & tokens = drafts[s].tokens;
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
}
int i_dft = 0;
int s_keep = 0;
llama_token token_id;
std::string token_str;
// loop until we fail to accept a drafted token or we run out of drafted tokens
while (true) {
// check if the target token matches any of the drafts
// for stochastic sampling, attempt to match the token with the drafted tokens
{
bool accept = false;
if (params.sparams.temp > 0) {
// stochastic verification
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
llama_sample_softmax(ctx_tgt, &dist_tgt);
float p_tgt = 0, p_dft = 0;
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
while (active_seqs.size() > 0) {
// randomly select a sequence to verify from active sequences
std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1);
int s = *std::next(active_seqs.begin(), u_int_dist(rng));
if (i_dft >= (int) drafts[s].tokens.size()) {
drafts[s].active = false;
active_seqs.erase(s);
continue;
}
if (accept) {
// if we already accepted a token, we can skip the rest
if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
drafts[s].active = false;
active_seqs.erase(s);
}
continue;
}
LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
float r = u_dist(rng);
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
// acquire the token probabilities assigned by the draft and target models
for (size_t i = 0; i < dist_tgt.size; i++) {
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
p_tgt = dist_tgt.data[i].p;
}
if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
p_dft = dist_dft.data[i].p;
}
if (p_tgt && p_dft) {
break;
}
}
LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
if (r <= p_tgt / p_dft) {
s_keep = s;
accept = true;
token_id = drafts[s].tokens[i_dft];
token_str = llama_token_to_piece(ctx_tgt, token_id);
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
break;
} else {
LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
drafts[s].active = false;
// calculate residual probability
GGML_ASSERT(dist_tgt.sorted);
GGML_ASSERT(dist_dft.sorted);
float sum_probs = 0.0f;
// sort dist by id
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
return a.id < b.id;
});
std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
return a.id < b.id;
});
for (size_t i = 0; i < dist_tgt.size; i++) {
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
sum_probs += dist_tgt.data[i].p;
}
for (size_t i = 0; i < dist_tgt.size; i++) {
dist_tgt.data[i].p /= sum_probs;
}
// sort dist_tgt by p desc
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
return a.p > b.p;
});
}
active_seqs.erase(s);
for(int i = 0; i < n_seq_dft; i++) {
if (i == s) {
continue;
}
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
// synchronize active status for sequences with the same drafted token
drafts[i].active = drafts[i].active && accept;
if (!drafts[i].active) {
active_seqs.erase(s);
}
}
}
}
if (!accept) {
// all drafted tokens were rejected
// sample from the target model
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
token_id = llama_sample_token(ctx_tgt, &dist_tgt);
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
token_str = llama_token_to_piece(ctx_tgt, token_id);
}
} else {
// greedy verification
// sample from the target model
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
token_str = llama_token_to_piece(ctx_tgt, token_id);
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}
if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
s_keep = s;
accept = true;
} else {
drafts[s].active = false;
}
}
}
if (token_id == llama_token_eos(model_tgt)) {
has_eos = true;
}
++n_predict;
if (accept) {
++n_accept;
++n_past_tgt;
++n_past_dft;
++i_dft;
if (params.use_color) {
// Color token according to its origin sequence
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
} else {
printf("%s", token_str.c_str());
}
fflush(stdout);
continue;
} else {
printf("%s", token_str.c_str());
fflush(stdout);
break;
}
}
}
{
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
// TODO: simplify
{
2023-10-18 17:49:40 +02:00
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
}
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false;
drafts[s].tokens.clear();
drafts[s].i_batch_tgt.clear();
drafts[s].dists.clear();
}
// note: will be erased after the speculation phase
drafts[0].tokens.push_back(token_id);
drafts[0].dists.push_back(std::vector<llama_token_data>());
drafts[0].i_batch_tgt.push_back(0);
llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
++n_past_dft;
}
if (n_predict > params.n_predict || has_eos) {
break;
}
llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
int n_seq_cur = 1;
int n_past_cur = n_past_dft;
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false;
drafts[s].drafting = false;
}
drafts[0].active = true;
drafts[0].drafting = true;
drafts[0].i_batch_dft = 0;
llama_batch_clear(batch_tgt);
llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
batch_dft.n_tokens = 0;
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].skip = false;
}
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].drafting || drafts[s].skip) {
continue;
}
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
const auto & cur_p = drafts[s].ctx_sampling->cur;
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
}
std::vector<int> sa(1, s);
// attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
if (batch_tgt.seq_id[t][p] == s) {
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
batch_tgt.n_seq_id[t]++;
break;
}
}
}
// copy the draft state
drafts[n_seq_cur].active = true;
drafts[n_seq_cur].drafting = true;
drafts[n_seq_cur].skip = true;
drafts[n_seq_cur].tokens = drafts[s].tokens;
drafts[n_seq_cur].dists = drafts[s].dists;
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
sa.push_back(n_seq_cur);
n_seq_cur++;
} else {
break;
}
}
// add drafted token for each sequence
for (int is = 0; is < (int) sa.size(); ++is) {
const llama_token id = cur_p[is].id;
const int s = sa[is];
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
drafts[s].tokens.push_back(id);
// save cur_p.data into drafts[s].dists
drafts[s].dists.push_back(cur_p);
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;
llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
2023-10-18 17:49:40 +02:00
if (batch_tgt.n_tokens > n_draft) {
drafts[s].drafting = false;
}
}
}
// no sequence is drafting anymore
if (batch_dft.n_tokens == 0) {
break;
}
// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch_dft);
++n_past_cur;
++n_drafted;
if (batch_tgt.n_tokens > n_draft) {
break;
}
}
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
}
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
++n_past_tgt;
}
// the first token is always proposed by the target model before the speculation loop so we erase it here
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}
drafts[s].tokens.erase(drafts[s].tokens.begin());
drafts[s].dists.erase(drafts[s].dists.begin());
}
}
auto t_dec_end = ggml_time_us();
LOG_TEE("\n\n");
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_drafted = %d\n", n_drafted);
LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("\ndraft:\n");
llama_print_timings(ctx_dft);
LOG_TEE("\ntarget:\n");
llama_print_timings(ctx_tgt);
llama_sampling_free(ctx_sampling);
for (int s = 0; s < n_seq_dft; ++s) {
llama_sampling_free(drafts[s].ctx_sampling);
}
llama_batch_free(batch_dft);
llama_free(ctx_tgt);
llama_free_model(model_tgt);
llama_free(ctx_dft);
llama_free_model(model_dft);
llama_backend_free();
fprintf(stderr, "\n\n");
return 0;
}