lookahead : filter repeating n-grams

This commit is contained in:
Georgi Gerganov 2023-11-25 17:02:56 +02:00
parent 61d039727a
commit 6eb5166e5a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -38,9 +38,9 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const int W = 10; // lookahead window const int W = 15; // lookahead window
const int N = 8; // n-gram size const int N = 5; // n-gram size
const int G = 10; // max verification n-grams const int G = 15; // max verification n-grams
const bool dump_kv_cache = params.dump_kv_cache; const bool dump_kv_cache = params.dump_kv_cache;
@ -128,8 +128,8 @@ int main(int argc, char ** argv) {
for (int i = 0; i < W; i++) { for (int i = 0; i < W; i++) {
// initialize randomly from the prompt tokens // initialize randomly from the prompt tokens
//tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
tokens_j[j][i] = 100 + i; //tokens_j[j][i] = 100 + i;
} }
} }
@ -234,6 +234,8 @@ int main(int argc, char ** argv) {
if (ngrams_cur[g].active) { if (ngrams_cur[g].active) {
i_batch = ngrams_cur[g].i_batch[v]; i_batch = ngrams_cur[g].i_batch[v];
seq_id_best = ngrams_cur[g].seq_id; seq_id_best = ngrams_cur[g].seq_id;
++n_accept;
break; break;
} }
} }
@ -281,14 +283,13 @@ int main(int argc, char ** argv) {
} else { } else {
if (id != ngrams_cur[g].tokens[v + 1]) { if (id != ngrams_cur[g].tokens[v + 1]) {
ngrams_cur[g].active = false; ngrams_cur[g].active = false;
} else {
} }
} }
} }
} }
// print known n-grams starting with token id // print known n-grams starting with token id (debug)
if (0) { if (0 && v == 0) {
if (ngrams_observed.cnt[id] > 0) { if (ngrams_observed.cnt[id] > 0) {
printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str()); printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
} }
@ -326,8 +327,8 @@ int main(int argc, char ** argv) {
} else { } else {
for (int i = 0; i < W; i++) { for (int i = 0; i < W; i++) {
// random init // random init
//tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
tokens_j[N - 2][i] = tokens_j[0][i]; //tokens_j[N - 2][i] = tokens_j[0][i];
} }
} }
} }
@ -340,11 +341,38 @@ int main(int argc, char ** argv) {
// n-gram generation // n-gram generation
// ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518 // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
for (int f = 0; f < W; ++f) { for (int f = 0; f < W; ++f) {
const int ft = tokens_j_prev[f]; // first token of the n-gram
for (int j = 0; j < N - 1; ++j) { for (int j = 0; j < N - 1; ++j) {
ngram[j] = tokens_j[j][f]; ngram[j] = tokens_j[j][f];
}; }
// filter-out repeating n-grams
{
bool is_unique = true;
for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
const int idx = ft*(N - 1)*G + k*(N - 1);
bool is_match = true;
for (int j = 0; j < N - 1; ++j) {
if (ngrams_observed.tokens[idx + j] != ngram[j]) {
is_match = false;
break;
}
}
if (is_match) {
is_unique = false;
break;
}
}
if (!is_unique) {
continue;
}
}
const int ft = tokens_j_prev[f]; // first token of the n-gram
const int head = ngrams_observed.head[ft]; const int head = ngrams_observed.head[ft];
const int idx = ft*(N - 1)*G + head*(N - 1); const int idx = ft*(N - 1)*G + head*(N - 1);
@ -360,6 +388,10 @@ int main(int argc, char ** argv) {
} }
} }
if (n_predict > params.n_predict || has_eos) {
break;
}
llama_kv_cache_seq_rm(ctx, -1, n_past, -1); llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
if (seq_id_best != 0) { if (seq_id_best != 0) {