mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
lookahead : generate and store n-grams
This commit is contained in:
parent
7c517e1722
commit
eb03b9ad69
@ -12,6 +12,21 @@ struct seq_ngram {
|
|||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ngram_container {
|
||||||
|
ngram_container(int n_vocab, int N, int G) {
|
||||||
|
cnt.resize(n_vocab);
|
||||||
|
head.resize(n_vocab);
|
||||||
|
tokens.resize(n_vocab * (N - 1)*G);
|
||||||
|
}
|
||||||
|
|
||||||
|
int n_total = 0;
|
||||||
|
|
||||||
|
std::vector<int> cnt;
|
||||||
|
std::vector<int> head;
|
||||||
|
|
||||||
|
std::vector<llama_token> tokens;
|
||||||
|
};
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
@ -99,10 +114,10 @@ int main(int argc, char ** argv) {
|
|||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
|
|
||||||
// verification n-grams
|
// verification n-grams
|
||||||
std::vector<seq_ngram> drafts(G);
|
std::vector<seq_ngram> ngrams(G);
|
||||||
|
|
||||||
// tokens for the past N - 1 Jacobi iterations
|
// tokens for the past N - 1 Jacobi iterations
|
||||||
// TODO: how to initialize?
|
std::vector<llama_token> tokens_j_prev(W);
|
||||||
std::vector<std::vector<llama_token>> tokens_j(N - 1);
|
std::vector<std::vector<llama_token>> tokens_j(N - 1);
|
||||||
for (int j = 0; j < N - 1; j++) {
|
for (int j = 0; j < N - 1; j++) {
|
||||||
tokens_j[j].resize(W);
|
tokens_j[j].resize(W);
|
||||||
@ -121,6 +136,8 @@ int main(int argc, char ** argv) {
|
|||||||
seq_id_all[i] = i;
|
seq_id_all[i] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ngram_container ngrams_observed(llama_n_vocab(model), N, G);
|
||||||
|
|
||||||
// debug
|
// debug
|
||||||
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
|
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
|
||||||
|
|
||||||
@ -188,8 +205,33 @@ int main(int argc, char ** argv) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// print known n-grams starting with token id
|
||||||
|
if (1) {
|
||||||
|
if (ngrams_observed.cnt[id] > 0) {
|
||||||
|
printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
|
||||||
|
printf(" - ngram %2d: ", i);
|
||||||
|
|
||||||
|
const int idx = id*(N - 1)*G + i*(N - 1);
|
||||||
|
|
||||||
|
for (int j = 0; j < N - 1; j++) {
|
||||||
|
const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
|
||||||
|
|
||||||
|
printf("%s", token_str.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// update Jacobi tokens (or whatever these are called)
|
// update Jacobi tokens (or whatever these are called)
|
||||||
{
|
{
|
||||||
|
for (int i = 0; i < W; i++) {
|
||||||
|
tokens_j_prev[i] = tokens_j[0][i];
|
||||||
|
}
|
||||||
|
|
||||||
for (int j = 0; j < N - 2; j++) {
|
for (int j = 0; j < N - 2; j++) {
|
||||||
tokens_j[j] = tokens_j[j + 1];
|
tokens_j[j] = tokens_j[j + 1];
|
||||||
}
|
}
|
||||||
@ -199,6 +241,40 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update observed ngrams
|
||||||
|
{
|
||||||
|
// the first token of the n-gram is determined by the index in the container so it is not stored
|
||||||
|
std::vector<llama_token> ngram(N - 1);
|
||||||
|
|
||||||
|
// n-gram generation
|
||||||
|
for (int f = 0; f < W; ++f) {
|
||||||
|
std::function<void(int)> rec = [&](int j) {
|
||||||
|
if (j == N - 1) {
|
||||||
|
const int ft = tokens_j_prev[f]; // first token of the n-gram
|
||||||
|
const int head = ngrams_observed.head[ft];
|
||||||
|
const int idx = ft*(N - 1)*G + head*(N - 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < N - 1; i++) {
|
||||||
|
ngrams_observed.tokens[idx + i] = ngram[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
|
||||||
|
ngrams_observed.head[ft] = (head + 1) % G;
|
||||||
|
|
||||||
|
ngrams_observed.n_total++;
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ngram[j] = tokens_j[j][f];
|
||||||
|
|
||||||
|
rec(j + 1);
|
||||||
|
};
|
||||||
|
|
||||||
|
rec(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// verification
|
// verification
|
||||||
// TODO
|
// TODO
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user