mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-19 08:20:10 +01:00
lookahead : filter repeating n-grams
This commit is contained in:
parent
61d039727a
commit
6eb5166e5a
@ -38,9 +38,9 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int W = 10; // lookahead window
|
||||
const int N = 8; // n-gram size
|
||||
const int G = 10; // max verification n-grams
|
||||
const int W = 15; // lookahead window
|
||||
const int N = 5; // n-gram size
|
||||
const int G = 15; // max verification n-grams
|
||||
|
||||
const bool dump_kv_cache = params.dump_kv_cache;
|
||||
|
||||
@ -128,8 +128,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
for (int i = 0; i < W; i++) {
|
||||
// initialize randomly from the prompt tokens
|
||||
//tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
|
||||
tokens_j[j][i] = 100 + i;
|
||||
tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
|
||||
//tokens_j[j][i] = 100 + i;
|
||||
}
|
||||
}
|
||||
|
||||
@ -234,6 +234,8 @@ int main(int argc, char ** argv) {
|
||||
if (ngrams_cur[g].active) {
|
||||
i_batch = ngrams_cur[g].i_batch[v];
|
||||
seq_id_best = ngrams_cur[g].seq_id;
|
||||
|
||||
++n_accept;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -281,14 +283,13 @@ int main(int argc, char ** argv) {
|
||||
} else {
|
||||
if (id != ngrams_cur[g].tokens[v + 1]) {
|
||||
ngrams_cur[g].active = false;
|
||||
} else {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// print known n-grams starting with token id
|
||||
if (0) {
|
||||
// print known n-grams starting with token id (debug)
|
||||
if (0 && v == 0) {
|
||||
if (ngrams_observed.cnt[id] > 0) {
|
||||
printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
|
||||
}
|
||||
@ -326,8 +327,8 @@ int main(int argc, char ** argv) {
|
||||
} else {
|
||||
for (int i = 0; i < W; i++) {
|
||||
// random init
|
||||
//tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
|
||||
tokens_j[N - 2][i] = tokens_j[0][i];
|
||||
tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
|
||||
//tokens_j[N - 2][i] = tokens_j[0][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -340,11 +341,38 @@ int main(int argc, char ** argv) {
|
||||
// n-gram generation
|
||||
// ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
|
||||
for (int f = 0; f < W; ++f) {
|
||||
const int ft = tokens_j_prev[f]; // first token of the n-gram
|
||||
|
||||
for (int j = 0; j < N - 1; ++j) {
|
||||
ngram[j] = tokens_j[j][f];
|
||||
};
|
||||
}
|
||||
|
||||
// filter-out repeating n-grams
|
||||
{
|
||||
bool is_unique = true;
|
||||
|
||||
for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
|
||||
const int idx = ft*(N - 1)*G + k*(N - 1);
|
||||
|
||||
bool is_match = true;
|
||||
for (int j = 0; j < N - 1; ++j) {
|
||||
if (ngrams_observed.tokens[idx + j] != ngram[j]) {
|
||||
is_match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_match) {
|
||||
is_unique = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_unique) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const int ft = tokens_j_prev[f]; // first token of the n-gram
|
||||
const int head = ngrams_observed.head[ft];
|
||||
const int idx = ft*(N - 1)*G + head*(N - 1);
|
||||
|
||||
@ -360,6 +388,10 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
if (n_predict > params.n_predict || has_eos) {
|
||||
break;
|
||||
}
|
||||
|
||||
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
|
||||
|
||||
if (seq_id_best != 0) {
|
||||
|
Loading…
Reference in New Issue
Block a user