mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 21:10:24 +01:00
examples : add passkey test (#3856)
* examples : add passkey test * passkey : better prints * passkey : select pass key pos from CLI * passkey : simplify n_past logic * make : add passkey target * passkey : add "self-extend"-like context extension (#4810) * llama : "self-extend"-like context extension * passkey : add comment * passkey : add readme
This commit is contained in:
parent
b7e7982953
commit
b0034d93ce
1
.gitignore
vendored
1
.gitignore
vendored
@ -51,6 +51,7 @@ models-mnt
|
|||||||
/lookup
|
/lookup
|
||||||
/main
|
/main
|
||||||
/metal
|
/metal
|
||||||
|
/passkey
|
||||||
/perplexity
|
/perplexity
|
||||||
/q8dot
|
/q8dot
|
||||||
/quantize
|
/quantize
|
||||||
|
5
Makefile
5
Makefile
@ -2,7 +2,7 @@
|
|||||||
BUILD_TARGETS = \
|
BUILD_TARGETS = \
|
||||||
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
|
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
|
||||||
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
|
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
|
||||||
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o
|
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
|
||||||
|
|
||||||
# Binaries only useful for tests
|
# Binaries only useful for tests
|
||||||
TEST_TARGETS = \
|
TEST_TARGETS = \
|
||||||
@ -665,6 +665,9 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS
|
|||||||
lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
|
passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
||||||
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
ifdef LLAMA_METAL
|
ifdef LLAMA_METAL
|
||||||
metal: examples/metal/metal.cpp ggml.o $(OBJS)
|
metal: examples/metal/metal.cpp ggml.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
|
||||||
|
@ -31,6 +31,7 @@ else()
|
|||||||
add_subdirectory(quantize-stats)
|
add_subdirectory(quantize-stats)
|
||||||
add_subdirectory(save-load-state)
|
add_subdirectory(save-load-state)
|
||||||
add_subdirectory(simple)
|
add_subdirectory(simple)
|
||||||
|
add_subdirectory(passkey)
|
||||||
add_subdirectory(speculative)
|
add_subdirectory(speculative)
|
||||||
add_subdirectory(lookahead)
|
add_subdirectory(lookahead)
|
||||||
add_subdirectory(lookup)
|
add_subdirectory(lookup)
|
||||||
|
@ -69,6 +69,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
std::vector<llama_token> tokens_list;
|
std::vector<llama_token> tokens_list;
|
||||||
tokens_list = ::llama_tokenize(model, params.prompt, true);
|
tokens_list = ::llama_tokenize(model, params.prompt, true);
|
||||||
|
|
||||||
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
|
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
|
||||||
|
|
||||||
// initialize the context
|
// initialize the context
|
||||||
|
5
examples/passkey/CMakeLists.txt
Normal file
5
examples/passkey/CMakeLists.txt
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
set(TARGET passkey)
|
||||||
|
add_executable(${TARGET} passkey.cpp)
|
||||||
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
|
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
12
examples/passkey/README.md
Normal file
12
examples/passkey/README.md
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# llama.cpp/example/passkey
|
||||||
|
|
||||||
|
See the following PRs for more info:
|
||||||
|
|
||||||
|
- https://github.com/ggerganov/llama.cpp/pull/3856
|
||||||
|
- https://github.com/ggerganov/llama.cpp/pull/4810
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make -j && ./passkey ./models/llama-7b-v2/ggml-model-f16.gguf 250
|
||||||
|
```
|
296
examples/passkey/passkey.cpp
Normal file
296
examples/passkey/passkey.cpp
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
#include "common.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
int main(int argc, char ** argv) {
|
||||||
|
gpt_params params;
|
||||||
|
|
||||||
|
if (argc == 1 || argv[1][0] == '-') {
|
||||||
|
printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]);
|
||||||
|
return 1 ;
|
||||||
|
}
|
||||||
|
|
||||||
|
int seed = -1;
|
||||||
|
|
||||||
|
int n_junk = 250; // number of times to repeat the junk text
|
||||||
|
int n_keep = 32; // number of tokens in the prompt prefix
|
||||||
|
int n_grp = 1; // if more than 1 - perform LongLM SelfExtend
|
||||||
|
int i_pos = -1; // position of the passkey in the junk text
|
||||||
|
|
||||||
|
if (argc >= 2) {
|
||||||
|
params.model = argv[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argc >= 3) {
|
||||||
|
n_junk = std::stoi(argv[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argc >= 4) {
|
||||||
|
n_grp = std::stoi(argv[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argc >= 5) {
|
||||||
|
i_pos = std::stoi(argv[4]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argc >= 6) {
|
||||||
|
seed = std::stoi(argv[5]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (seed == -1) {
|
||||||
|
seed = time(NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
srand(seed);
|
||||||
|
|
||||||
|
if (i_pos == -1) {
|
||||||
|
i_pos = rand() % n_junk;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.";
|
||||||
|
const std::string prompt_suffix = " What is the pass key? The pass key is";
|
||||||
|
|
||||||
|
// generate junk text
|
||||||
|
params.prompt = prompt_prefix;
|
||||||
|
|
||||||
|
const int passkey = rand() % 50000 + 1;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_junk; i++) {
|
||||||
|
if (i % n_junk == i_pos) {
|
||||||
|
params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key.";
|
||||||
|
}
|
||||||
|
|
||||||
|
params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.";
|
||||||
|
}
|
||||||
|
|
||||||
|
params.prompt += prompt_suffix;
|
||||||
|
|
||||||
|
// init LLM
|
||||||
|
|
||||||
|
llama_backend_init(params.numa);
|
||||||
|
|
||||||
|
// initialize the model
|
||||||
|
|
||||||
|
llama_model_params model_params = llama_model_default_params();
|
||||||
|
|
||||||
|
model_params.n_gpu_layers = 99; // offload all layers to the GPU
|
||||||
|
|
||||||
|
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
|
||||||
|
|
||||||
|
if (model == NULL) {
|
||||||
|
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize the context
|
||||||
|
|
||||||
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
|
|
||||||
|
ctx_params.seed = seed;
|
||||||
|
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
|
||||||
|
ctx_params.n_batch = 512;
|
||||||
|
ctx_params.n_threads = params.n_threads;
|
||||||
|
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
||||||
|
|
||||||
|
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
|
||||||
|
|
||||||
|
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||||
|
|
||||||
|
if (ctx == NULL) {
|
||||||
|
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenize the prompt
|
||||||
|
std::vector<llama_token> tokens_list;
|
||||||
|
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
||||||
|
|
||||||
|
// tokenize the prefix and use it as a sink
|
||||||
|
const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size();
|
||||||
|
|
||||||
|
const int n_tokens_all = tokens_list.size();
|
||||||
|
|
||||||
|
// we leave a margin of 16 tokens for the generated text - it should contain just the passkey
|
||||||
|
const int n_predict = 16;
|
||||||
|
|
||||||
|
// total length of the sequences including the prompt
|
||||||
|
const int n_len = n_tokens_all + n_predict;
|
||||||
|
|
||||||
|
const int n_ctx = llama_n_ctx(ctx) - n_keep;
|
||||||
|
const int n_kv_req = llama_n_ctx(ctx);
|
||||||
|
const int n_batch = ctx_params.n_batch;
|
||||||
|
const int n_batch_grp = ctx_params.n_batch/n_grp;
|
||||||
|
|
||||||
|
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
|
||||||
|
|
||||||
|
// print the prompt token-by-token
|
||||||
|
|
||||||
|
LOG_TEE("\n");
|
||||||
|
LOG_TEE("prefix tokens: %d\n", n_tokens_prefix);
|
||||||
|
LOG_TEE("prompt tokens: %d\n", n_tokens_all);
|
||||||
|
//LOG_TEE("prompt: %s\n", params.prompt.c_str());
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||||
|
|
||||||
|
int n_past = 0;
|
||||||
|
|
||||||
|
// fill the KV cache
|
||||||
|
for (int i = 0; i < n_ctx; i += n_batch) {
|
||||||
|
if (i > 0 && n_grp > 1) {
|
||||||
|
// if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp
|
||||||
|
const int ib = i/n_batch - 1;
|
||||||
|
const int bd = n_batch_grp*(n_grp - 1);
|
||||||
|
|
||||||
|
llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd);
|
||||||
|
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
||||||
|
|
||||||
|
n_past -= bd;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
|
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
||||||
|
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i + n_batch >= n_tokens_all) {
|
||||||
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_decode(ctx, batch) != 0) {
|
||||||
|
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
|
||||||
|
|
||||||
|
if (i + n_batch >= n_tokens_all) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = n_ctx; i < n_tokens_all; i += n_batch) {
|
||||||
|
const int n_discard = n_batch;
|
||||||
|
|
||||||
|
LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
|
||||||
|
|
||||||
|
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||||
|
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||||
|
|
||||||
|
n_past -= n_discard;
|
||||||
|
|
||||||
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
|
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
||||||
|
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i + n_batch >= n_tokens_all) {
|
||||||
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_decode(ctx, batch) != 0) {
|
||||||
|
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const int n_discard = n_past - n_ctx + n_predict;
|
||||||
|
|
||||||
|
if (n_discard > 0) {
|
||||||
|
LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
|
||||||
|
|
||||||
|
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||||
|
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||||
|
|
||||||
|
n_past -= n_discard;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_TEE("\n");
|
||||||
|
LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk);
|
||||||
|
LOG_TEE("\n");
|
||||||
|
|
||||||
|
// main loop
|
||||||
|
|
||||||
|
int n_cur = n_tokens_all;
|
||||||
|
int n_decode = 0;
|
||||||
|
|
||||||
|
LOG_TEE("%s", prompt_suffix.c_str());
|
||||||
|
fflush(stdout);
|
||||||
|
|
||||||
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
|
while (n_cur <= n_len) {
|
||||||
|
// sample the next token
|
||||||
|
{
|
||||||
|
auto n_vocab = llama_n_vocab(model);
|
||||||
|
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||||
|
|
||||||
|
std::vector<llama_token_data> candidates;
|
||||||
|
candidates.reserve(n_vocab);
|
||||||
|
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
|
// sample the most likely token
|
||||||
|
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||||
|
|
||||||
|
// is it an end of stream?
|
||||||
|
if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
|
||||||
|
LOG_TEE("\n");
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
||||||
|
fflush(stdout);
|
||||||
|
|
||||||
|
n_decode += 1;
|
||||||
|
|
||||||
|
// prepare the next batch
|
||||||
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
|
// push this new token for next evaluation
|
||||||
|
llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
n_cur += 1;
|
||||||
|
|
||||||
|
// evaluate the current batch with the transformer model
|
||||||
|
if (llama_decode(ctx, batch)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_TEE("\n");
|
||||||
|
|
||||||
|
const auto t_main_end = ggml_time_us();
|
||||||
|
|
||||||
|
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||||
|
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||||
|
|
||||||
|
llama_print_timings(ctx);
|
||||||
|
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
|
||||||
|
llama_free(ctx);
|
||||||
|
llama_free_model(model);
|
||||||
|
|
||||||
|
llama_backend_free();
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
34
llama.cpp
34
llama.cpp
@ -1903,6 +1903,28 @@ static void llama_kv_cache_seq_shift(
|
|||||||
cache.head = new_head != cache.size ? new_head : 0;
|
cache.head = new_head != cache.size ? new_head : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void llama_kv_cache_seq_div(
|
||||||
|
struct llama_kv_cache & cache,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1,
|
||||||
|
int d) {
|
||||||
|
if (p0 < 0) p0 = 0;
|
||||||
|
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||||
|
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||||
|
cache.has_shift = true;
|
||||||
|
|
||||||
|
{
|
||||||
|
llama_pos p_old = cache.cells[i].pos;
|
||||||
|
cache.cells[i].pos /= d;
|
||||||
|
cache.cells[i].delta += cache.cells[i].pos - p_old;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// model loading and saving
|
// model loading and saving
|
||||||
//
|
//
|
||||||
@ -10140,9 +10162,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||||
|
if (delta == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
|
llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||||
|
if (d == 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the *maximum* size of the state
|
// Returns the *maximum* size of the state
|
||||||
size_t llama_get_state_size(const struct llama_context * ctx) {
|
size_t llama_get_state_size(const struct llama_context * ctx) {
|
||||||
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
|
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user