diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp
index 33af03a3e..6f841fff4 100644
--- a/examples/lookahead/lookahead.cpp
+++ b/examples/lookahead/lookahead.cpp
@@ -38,9 +38,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int W = 10; // lookahead window
-    const int N = 8;  // n-gram size
-    const int G = 10; // max verification n-grams
+    const int W = 15; // lookahead window
+    const int N = 5;  // n-gram size
+    const int G = 15; // max verification n-grams
 
     const bool dump_kv_cache = params.dump_kv_cache;
 
@@ -128,8 +128,8 @@ int main(int argc, char ** argv) {
 
         for (int i = 0; i < W; i++) {
             // initialize randomly from the prompt tokens
-            //tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
-            tokens_j[j][i] = 100 + i;
+            tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
+            //tokens_j[j][i] = 100 + i;
         }
     }
 
@@ -234,6 +234,8 @@ int main(int argc, char ** argv) {
                     if (ngrams_cur[g].active) {
                         i_batch = ngrams_cur[g].i_batch[v];
                         seq_id_best = ngrams_cur[g].seq_id;
+
+                        ++n_accept;
                         break;
                     }
                 }
@@ -281,14 +283,13 @@ int main(int argc, char ** argv) {
                     } else {
                         if (id != ngrams_cur[g].tokens[v + 1]) {
                             ngrams_cur[g].active = false;
-                        } else {
                         }
                     }
                 }
             }
 
-            // print known n-grams starting with token id
-            if (0) {
+            // print known n-grams starting with token id (debug)
+            if (0 && v == 0) {
                 if (ngrams_observed.cnt[id] > 0) {
                     printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
                 }
@@ -326,8 +327,8 @@ int main(int argc, char ** argv) {
                 } else {
                     for (int i = 0; i < W; i++) {
                         // random init
-                        //tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
-                        tokens_j[N - 2][i] = tokens_j[0][i];
+                        tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
+                        //tokens_j[N - 2][i] = tokens_j[0][i];
                     }
                 }
             }
@@ -340,11 +341,38 @@ int main(int argc, char ** argv) {
                 // n-gram generation
                 // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
                 for (int f = 0; f < W; ++f) {
+                    const int ft = tokens_j_prev[f]; // first token of the n-gram
+
                     for (int j = 0; j < N - 1; ++j) {
                         ngram[j] = tokens_j[j][f];
-                    };
+                    }
+
+                    // filter-out repeating n-grams
+                    {
+                        bool is_unique = true;
+
+                        for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
+                            const int idx = ft*(N - 1)*G + k*(N - 1);
+
+                            bool is_match = true;
+                            for (int j = 0; j < N - 1; ++j) {
+                                if (ngrams_observed.tokens[idx + j] != ngram[j]) {
+                                    is_match = false;
+                                    break;
+                                }
+                            }
+
+                            if (is_match) {
+                                is_unique = false;
+                                break;
+                            }
+                        }
+
+                        if (!is_unique) {
+                            continue;
+                        }
+                    }
 
-                    const int ft = tokens_j_prev[f]; // first token of the n-gram
                     const int head = ngrams_observed.head[ft];
                     const int idx = ft*(N - 1)*G + head*(N - 1);
 
@@ -360,6 +388,10 @@ int main(int argc, char ** argv) {
             }
         }
 
+        if (n_predict > params.n_predict || has_eos) {
+            break;
+        }
+
        llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
 
        if (seq_id_best != 0) {
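For reference, below is a minimal standalone sketch of the lookup that the new "filter-out repeating n-grams" block performs. The ngram_pool struct and the pool_contains helper are hypothetical names introduced here for illustration and are not part of the patch; what is taken from the patch is only the flat pool layout, in which the stored n-grams sharing a first token ft occupy up to G slots of N - 1 tokens each, starting at offset ft*(N - 1)*G, with cnt[ft] slots currently in use.

// Illustrative sketch only (not from the patch): a standalone version of the
// duplicate check added above. The names ngram_pool / pool_contains are
// hypothetical; the indexing ft*(N - 1)*G + k*(N - 1) mirrors the added code.
#include <cstdint>
#include <vector>

using llama_token = int32_t;

struct ngram_pool {
    int N; // n-gram size
    int G; // max n-grams stored per first token
    std::vector<llama_token> tokens; // n_vocab * G * (N - 1) token slots
    std::vector<int>         cnt;    // per first token: number of slots in use
};

// true if the (N - 1)-token continuation 'ngram' is already stored for first token 'ft'
static bool pool_contains(const ngram_pool & pool, llama_token ft, const std::vector<llama_token> & ngram) {
    for (int k = 0; k < pool.cnt[ft]; ++k) {
        const int idx = ft*(pool.N - 1)*pool.G + k*(pool.N - 1); // start of slot k for this first token

        bool is_match = true;
        for (int j = 0; j < pool.N - 1; ++j) {
            if (pool.tokens[idx + j] != ngram[j]) {
                is_match = false;
                break;
            }
        }

        if (is_match) {
            return true; // already stored; the patch skips insertion in this case
        }
    }

    return false;
}

In the example itself the same check is done inline: on a match, the newly generated n-gram is skipped with continue before it would be written into the slot selected by ngrams_observed.head[ft].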