diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index e5fa37b81..c45184b14 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -169,9 +169,11 @@ int main(int argc, char ** argv) { llama_batch_clear(batch); llama_batch_add(batch, id, n_past, seq_id_all, true); + for (int i = 1; i < W; i++) { llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); } + for (int j = 1; j < N - 1; j++) { for (int i = 0; i < W; i++) { llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); @@ -248,30 +250,22 @@ int main(int argc, char ** argv) { // n-gram generation for (int f = 0; f < W; ++f) { - std::function rec = [&](int j) { - if (j == N - 1) { - const int ft = tokens_j_prev[f]; // first token of the n-gram - const int head = ngrams_observed.head[ft]; - const int idx = ft*(N - 1)*G + head*(N - 1); - - for (int i = 0; i < N - 1; i++) { - ngrams_observed.tokens[idx + i] = ngram[i]; - } - - ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1); - ngrams_observed.head[ft] = (head + 1) % G; - - ngrams_observed.n_total++; - - return; - } - + for (int j = 0; j < N - 1; ++j) { ngram[j] = tokens_j[j][f]; - - rec(j + 1); }; - rec(0); + const int ft = tokens_j_prev[f]; // first token of the n-gram + const int head = ngrams_observed.head[ft]; + const int idx = ft*(N - 1)*G + head*(N - 1); + + for (int i = 0; i < N - 1; i++) { + ngrams_observed.tokens[idx + i] = ngram[i]; + } + + ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1); + ngrams_observed.head[ft] = (head + 1) % G; + + ngrams_observed.n_total++; } }