mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
lookahead : use loop instead recursion to generate n-grams
This commit is contained in:
parent
eb03b9ad69
commit
1b2e0bc3e6
@ -169,9 +169,11 @@ int main(int argc, char ** argv) {
|
|||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
llama_batch_add(batch, id, n_past, seq_id_all, true);
|
llama_batch_add(batch, id, n_past, seq_id_all, true);
|
||||||
|
|
||||||
for (int i = 1; i < W; i++) {
|
for (int i = 1; i < W; i++) {
|
||||||
llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
|
llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int j = 1; j < N - 1; j++) {
|
for (int j = 1; j < N - 1; j++) {
|
||||||
for (int i = 0; i < W; i++) {
|
for (int i = 0; i < W; i++) {
|
||||||
llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
|
llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
|
||||||
@ -248,30 +250,22 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// n-gram generation
|
// n-gram generation
|
||||||
for (int f = 0; f < W; ++f) {
|
for (int f = 0; f < W; ++f) {
|
||||||
std::function<void(int)> rec = [&](int j) {
|
for (int j = 0; j < N - 1; ++j) {
|
||||||
if (j == N - 1) {
|
|
||||||
const int ft = tokens_j_prev[f]; // first token of the n-gram
|
|
||||||
const int head = ngrams_observed.head[ft];
|
|
||||||
const int idx = ft*(N - 1)*G + head*(N - 1);
|
|
||||||
|
|
||||||
for (int i = 0; i < N - 1; i++) {
|
|
||||||
ngrams_observed.tokens[idx + i] = ngram[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
|
|
||||||
ngrams_observed.head[ft] = (head + 1) % G;
|
|
||||||
|
|
||||||
ngrams_observed.n_total++;
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ngram[j] = tokens_j[j][f];
|
ngram[j] = tokens_j[j][f];
|
||||||
|
|
||||||
rec(j + 1);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
rec(0);
|
const int ft = tokens_j_prev[f]; // first token of the n-gram
|
||||||
|
const int head = ngrams_observed.head[ft];
|
||||||
|
const int idx = ft*(N - 1)*G + head*(N - 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < N - 1; i++) {
|
||||||
|
ngrams_observed.tokens[idx + i] = ngram[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
|
||||||
|
ngrams_observed.head[ft] = (head + 1) % G;
|
||||||
|
|
||||||
|
ngrams_observed.n_total++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user