mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-04 01:57:53 +01:00
examples : add example for batched decoding
This commit is contained in:
parent
d008733e6b
commit
a207561503
1
.gitignore
vendored
1
.gitignore
vendored
@ -51,6 +51,7 @@ models-mnt
|
|||||||
/save-load-state
|
/save-load-state
|
||||||
/server
|
/server
|
||||||
/simple
|
/simple
|
||||||
|
/batched
|
||||||
/speculative
|
/speculative
|
||||||
/parallel
|
/parallel
|
||||||
/train-text-from-scratch
|
/train-text-from-scratch
|
||||||
|
5
Makefile
5
Makefile
@ -1,5 +1,5 @@
|
|||||||
# Define the default target now so that it is always the first target
|
# Define the default target now so that it is always the first target
|
||||||
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel tests/test-c.o
|
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel tests/test-c.o
|
||||||
|
|
||||||
# Binaries only useful for tests
|
# Binaries only useful for tests
|
||||||
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama
|
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama
|
||||||
@ -519,6 +519,9 @@ main: examples/main/main.cpp build-info.h ggml.
|
|||||||
simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
|
# Build the batched-decoding example from examples/batched/batched.cpp and the listed object files.
# build-info.h is a prerequisite only: $(filter-out %.h,$^) keeps headers off the compiler command line.
batched: examples/batched/batched.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
||||||
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
|
quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ else()
|
|||||||
add_subdirectory(train-text-from-scratch)
|
add_subdirectory(train-text-from-scratch)
|
||||||
add_subdirectory(convert-llama2c-to-ggml)
|
add_subdirectory(convert-llama2c-to-ggml)
|
||||||
add_subdirectory(simple)
|
add_subdirectory(simple)
|
||||||
|
add_subdirectory(batched)
|
||||||
add_subdirectory(speculative)
|
add_subdirectory(speculative)
|
||||||
add_subdirectory(parallel)
|
add_subdirectory(parallel)
|
||||||
add_subdirectory(embd-input)
|
add_subdirectory(embd-input)
|
||||||
|
5
examples/batched/CMakeLists.txt
Normal file
5
examples/batched/CMakeLists.txt
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Batched-decoding example: a single executable built from batched.cpp.
set(TARGET batched)
add_executable(${TARGET} batched.cpp)

# Require at least C++11 when compiling this target.
target_compile_features(${TARGET} PRIVATE cxx_std_11)

# Link the `common` and `llama` targets plus whatever CMAKE_THREAD_LIBS_INIT holds
# (presumably populated by find_package(Threads) in a parent CMakeLists.txt — confirm).
target_link_libraries(${TARGET}
    PRIVATE
        common
        llama
        ${CMAKE_THREAD_LIBS_INIT}
)

# Install the executable as a runtime artifact.
install(TARGETS ${TARGET} RUNTIME)
|
44
examples/batched/README.md
Normal file
44
examples/batched/README.md
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# llama.cpp/example/batched
|
||||||
|
|
||||||
|
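This example demonstrates batched generation: several sequences are generated in parallel from a single given prompt. The optional third argument sets the number of parallel sequences.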
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4
|
||||||
|
|
||||||
|
...
|
||||||
|
|
||||||
|
main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113
|
||||||
|
|
||||||
|
Hello my name is
|
||||||
|
|
||||||
|
main: generating 4 sequences ...
|
||||||
|
|
||||||
|
main: stream 0 finished
|
||||||
|
main: stream 1 finished
|
||||||
|
main: stream 2 finished
|
||||||
|
main: stream 3 finished
|
||||||
|
|
||||||
|
sequence 0:
|
||||||
|
|
||||||
|
Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b
|
||||||
|
|
||||||
|
sequence 1:
|
||||||
|
|
||||||
|
Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between
|
||||||
|
|
||||||
|
sequence 2:
|
||||||
|
|
||||||
|
Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am
|
||||||
|
|
||||||
|
sequence 3:
|
||||||
|
|
||||||
|
Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and
|
||||||
|
|
||||||
|
main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s
|
||||||
|
|
||||||
|
llama_print_timings: load time = 587.00 ms
|
||||||
|
llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second)
|
||||||
|
llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second)
|
||||||
|
llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
|
||||||
|
llama_print_timings: total time = 4156.04 ms
|
||||||
|
```
|
243
examples/batched/batched.cpp
Normal file
243
examples/batched/batched.cpp
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
#include "common.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
int main(int argc, char ** argv) {
|
||||||
|
gpt_params params;
|
||||||
|
|
||||||
|
if (argc == 1 || argv[1][0] == '-') {
|
||||||
|
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
|
||||||
|
return 1 ;
|
||||||
|
}
|
||||||
|
|
||||||
|
int n_parallel = 1;
|
||||||
|
|
||||||
|
if (argc >= 2) {
|
||||||
|
params.model = argv[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argc >= 3) {
|
||||||
|
params.prompt = argv[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argc >= 4) {
|
||||||
|
n_parallel = std::atoi(argv[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.prompt.empty()) {
|
||||||
|
params.prompt = "Hello my name is";
|
||||||
|
}
|
||||||
|
|
||||||
|
// total length of the sequences including the prompt
|
||||||
|
const int n_len = 32;
|
||||||
|
|
||||||
|
// init LLM
|
||||||
|
|
||||||
|
llama_backend_init(params.numa);
|
||||||
|
|
||||||
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
|
|
||||||
|
ctx_params.seed = 1234;
|
||||||
|
ctx_params.n_ctx = 2048;
|
||||||
|
|
||||||
|
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
|
||||||
|
|
||||||
|
if (model == NULL) {
|
||||||
|
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||||
|
|
||||||
|
if (ctx == NULL) {
|
||||||
|
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenize the prompt
|
||||||
|
|
||||||
|
std::vector<llama_token> tokens_list;
|
||||||
|
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
||||||
|
|
||||||
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
|
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
|
||||||
|
|
||||||
|
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
|
||||||
|
|
||||||
|
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
||||||
|
if (n_kv_req > n_ctx) {
|
||||||
|
LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
|
||||||
|
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// print the prompt token-by-token
|
||||||
|
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
|
for (auto id : tokens_list) {
|
||||||
|
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
fflush(stderr);
|
||||||
|
|
||||||
|
// create a llama_batch with size 512
|
||||||
|
// we use this object to submit token data for decoding
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_init(512, 0);
|
||||||
|
|
||||||
|
// evaluate the initial prompt
|
||||||
|
batch.n_tokens = tokens_list.size();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||||
|
batch.token[i] = tokens_list[i];
|
||||||
|
batch.pos[i] = i;
|
||||||
|
batch.seq_id[i] = 0;
|
||||||
|
batch.logits[i] = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// llama_decode will output logits only for the last token of the prompt
|
||||||
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
|
|
||||||
|
if (llama_decode(ctx, batch, params.n_threads) != 0) {
|
||||||
|
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign the system KV cache to all parallel sequences
|
||||||
|
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
|
||||||
|
for (int32_t i = 1; i < n_parallel; ++i) {
|
||||||
|
llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_parallel > 1) {
|
||||||
|
LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
|
||||||
|
}
|
||||||
|
|
||||||
|
// main loop
|
||||||
|
|
||||||
|
// we will store the parallel decoded sequences in this vector
|
||||||
|
std::vector<std::string> streams(n_parallel);
|
||||||
|
|
||||||
|
// remember the batch index of the last token for each parallel sequence
|
||||||
|
// we need this to determine which logits to sample from
|
||||||
|
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
|
||||||
|
|
||||||
|
int n_cur = batch.n_tokens;
|
||||||
|
int n_decode = 0;
|
||||||
|
|
||||||
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
|
while (n_cur <= n_len) {
|
||||||
|
// evaluate the current batch with the transformer model
|
||||||
|
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepare the next batch
|
||||||
|
batch.n_tokens = 0;
|
||||||
|
|
||||||
|
// sample the next token for each parallel sequence / stream
|
||||||
|
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||||
|
if (i_batch[i] < 0) {
|
||||||
|
// the stream has already finished
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto n_vocab = llama_n_vocab(ctx);
|
||||||
|
auto logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||||
|
|
||||||
|
std::vector<llama_token_data> candidates;
|
||||||
|
candidates.reserve(n_vocab);
|
||||||
|
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
|
const int top_k = 40;
|
||||||
|
const float top_p = 0.9f;
|
||||||
|
const float temp = 0.4f;
|
||||||
|
|
||||||
|
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
||||||
|
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||||
|
llama_sample_temp (ctx, &candidates_p, temp);
|
||||||
|
|
||||||
|
const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
|
||||||
|
|
||||||
|
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||||
|
|
||||||
|
// is it an end of stream? -> mark the stream as finished
|
||||||
|
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
|
||||||
|
i_batch[i] = -1;
|
||||||
|
LOG_TEE("\n");
|
||||||
|
if (n_parallel > 1) {
|
||||||
|
LOG_TEE("%s: stream %d finished", __func__, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there is only one stream, we print immediately to stdout
|
||||||
|
if (n_parallel == 1) {
|
||||||
|
LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
streams[i] += llama_token_to_piece(ctx, new_token_id);
|
||||||
|
|
||||||
|
// push this new token for next evaluation
|
||||||
|
batch.token [batch.n_tokens] = new_token_id;
|
||||||
|
batch.pos [batch.n_tokens] = n_cur;
|
||||||
|
batch.seq_id[batch.n_tokens] = i;
|
||||||
|
batch.logits[batch.n_tokens] = true;
|
||||||
|
|
||||||
|
i_batch[i] = batch.n_tokens;
|
||||||
|
|
||||||
|
batch.n_tokens += 1;
|
||||||
|
|
||||||
|
n_decode += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// all streams are finished
|
||||||
|
if (batch.n_tokens == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
n_cur += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_TEE("\n");
|
||||||
|
|
||||||
|
if (n_parallel > 1) {
|
||||||
|
LOG_TEE("\n");
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||||
|
LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto t_main_end = ggml_time_us();
|
||||||
|
|
||||||
|
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||||
|
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||||
|
|
||||||
|
llama_print_timings(ctx);
|
||||||
|
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
|
||||||
|
llama_free(ctx);
|
||||||
|
llama_free_model(model);
|
||||||
|
|
||||||
|
llama_backend_free();
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
@ -1,12 +1,9 @@
|
|||||||
# llama.cpp/example/simple
|
# llama.cpp/example/simple
|
||||||
|
|
||||||
The purpose of this example is to demonstrate a minimal usage of llama.cpp for generating text with a given prompt.
|
The purpose of this example is to demonstrate a minimal usage of llama.cpp for generating text with a given prompt.
|
||||||
The example demonstrates single-batch as well as parallel generation.
|
|
||||||
|
|
||||||
## Single-batch generation
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./simple ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 1
|
./simple ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is"
|
||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -22,46 +19,3 @@ llama_print_timings: prompt eval time = 655.63 ms / 10 tokens ( 65.56 ms
|
|||||||
llama_print_timings: eval time = 2180.97 ms / 27 runs ( 80.78 ms per token, 12.38 tokens per second)
|
llama_print_timings: eval time = 2180.97 ms / 27 runs ( 80.78 ms per token, 12.38 tokens per second)
|
||||||
llama_print_timings: total time = 2891.13 ms
|
llama_print_timings: total time = 2891.13 ms
|
||||||
```
|
```
|
||||||
|
|
||||||
## Parallel generation
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./simple ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4
|
|
||||||
|
|
||||||
...
|
|
||||||
|
|
||||||
main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113
|
|
||||||
|
|
||||||
Hello my name is
|
|
||||||
|
|
||||||
main: generating 4 sequences ...
|
|
||||||
|
|
||||||
main: stream 0 finished
|
|
||||||
main: stream 1 finished
|
|
||||||
main: stream 2 finished
|
|
||||||
main: stream 3 finished
|
|
||||||
|
|
||||||
sequence 0:
|
|
||||||
|
|
||||||
Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b
|
|
||||||
|
|
||||||
sequence 1:
|
|
||||||
|
|
||||||
Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between
|
|
||||||
|
|
||||||
sequence 2:
|
|
||||||
|
|
||||||
Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am
|
|
||||||
|
|
||||||
sequence 3:
|
|
||||||
|
|
||||||
Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and
|
|
||||||
|
|
||||||
main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s
|
|
||||||
|
|
||||||
llama_print_timings: load time = 587.00 ms
|
|
||||||
llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second)
|
|
||||||
llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second)
|
|
||||||
llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
|
|
||||||
llama_print_timings: total time = 4156.04 ms
|
|
||||||
```
|
|
||||||
|
@ -10,12 +10,10 @@ int main(int argc, char ** argv) {
|
|||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
if (argc == 1 || argv[1][0] == '-') {
|
if (argc == 1 || argv[1][0] == '-') {
|
||||||
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
|
printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]);
|
||||||
return 1 ;
|
return 1 ;
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_parallel = 1;
|
|
||||||
|
|
||||||
if (argc >= 2) {
|
if (argc >= 2) {
|
||||||
params.model = argv[1];
|
params.model = argv[1];
|
||||||
}
|
}
|
||||||
@ -24,15 +22,11 @@ int main(int argc, char ** argv) {
|
|||||||
params.prompt = argv[2];
|
params.prompt = argv[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (argc >= 4) {
|
|
||||||
n_parallel = std::atoi(argv[3]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.prompt.empty()) {
|
if (params.prompt.empty()) {
|
||||||
params.prompt = "Hello my name is";
|
params.prompt = "Hello my name is";
|
||||||
}
|
}
|
||||||
|
|
||||||
// total length of the sequences including the prompt
|
// total length of the sequence including the prompt
|
||||||
const int n_len = 32;
|
const int n_len = 32;
|
||||||
|
|
||||||
// init LLM
|
// init LLM
|
||||||
@ -64,9 +58,9 @@ int main(int argc, char ** argv) {
|
|||||||
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
||||||
|
|
||||||
const int n_ctx = llama_n_ctx(ctx);
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
|
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
|
||||||
|
|
||||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
|
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
|
||||||
|
|
||||||
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
||||||
if (n_kv_req > n_ctx) {
|
if (n_kv_req > n_ctx) {
|
||||||
@ -108,25 +102,8 @@ int main(int argc, char ** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// assign the system KV cache to all parallel sequences
|
|
||||||
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
|
|
||||||
for (int32_t i = 1; i < n_parallel; ++i) {
|
|
||||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n_parallel > 1) {
|
|
||||||
LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
|
|
||||||
}
|
|
||||||
|
|
||||||
// main loop
|
// main loop
|
||||||
|
|
||||||
// we will store the parallel decoded sequences in this vector
|
|
||||||
std::vector<std::string> streams(n_parallel);
|
|
||||||
|
|
||||||
// remember the batch index of the last token for each parallel sequence
|
|
||||||
// we need this to determine which logits to sample from
|
|
||||||
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
|
|
||||||
|
|
||||||
int n_cur = batch.n_tokens;
|
int n_cur = batch.n_tokens;
|
||||||
int n_decode = 0;
|
int n_decode = 0;
|
||||||
|
|
||||||
@ -139,18 +116,10 @@ int main(int argc, char ** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepare the next batch
|
// sample the next token
|
||||||
batch.n_tokens = 0;
|
{
|
||||||
|
|
||||||
// sample the next token for each parallel sequence / stream
|
|
||||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
|
||||||
if (i_batch[i] < 0) {
|
|
||||||
// the stream has already finished
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto n_vocab = llama_n_vocab(ctx);
|
auto n_vocab = llama_n_vocab(ctx);
|
||||||
auto logits = llama_get_logits_ith(ctx, i_batch[i]);
|
auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
@ -161,68 +130,38 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
const int top_k = 40;
|
// sample the most likely token
|
||||||
const float top_p = 0.9f;
|
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||||
const float temp = 0.4f;
|
|
||||||
|
|
||||||
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
// is it an end of stream?
|
||||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
|
||||||
llama_sample_temp (ctx, &candidates_p, temp);
|
|
||||||
|
|
||||||
const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
|
|
||||||
|
|
||||||
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
|
||||||
|
|
||||||
// is it an end of stream? -> mark the stream as finished
|
|
||||||
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
|
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
|
||||||
i_batch[i] = -1;
|
|
||||||
LOG_TEE("\n");
|
LOG_TEE("\n");
|
||||||
if (n_parallel > 1) {
|
|
||||||
LOG_TEE("%s: stream %d finished", __func__, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
continue;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// if there is only one stream, we print immediately to stdout
|
LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
||||||
if (n_parallel == 1) {
|
fflush(stdout);
|
||||||
LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
|
|
||||||
streams[i] += llama_token_to_piece(ctx, new_token_id);
|
// prepare the next batch
|
||||||
|
batch.n_tokens = 0;
|
||||||
|
|
||||||
// push this new token for next evaluation
|
// push this new token for next evaluation
|
||||||
batch.token [batch.n_tokens] = new_token_id;
|
batch.token [batch.n_tokens] = new_token_id;
|
||||||
batch.pos [batch.n_tokens] = n_cur;
|
batch.pos [batch.n_tokens] = n_cur;
|
||||||
batch.seq_id[batch.n_tokens] = i;
|
batch.seq_id[batch.n_tokens] = 0;
|
||||||
batch.logits[batch.n_tokens] = true;
|
batch.logits[batch.n_tokens] = true;
|
||||||
|
|
||||||
i_batch[i] = batch.n_tokens;
|
|
||||||
|
|
||||||
batch.n_tokens += 1;
|
batch.n_tokens += 1;
|
||||||
|
|
||||||
n_decode += 1;
|
n_decode += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// all streams are finished
|
|
||||||
if (batch.n_tokens == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
n_cur += 1;
|
n_cur += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("\n");
|
LOG_TEE("\n");
|
||||||
|
|
||||||
if (n_parallel > 1) {
|
|
||||||
LOG_TEE("\n");
|
|
||||||
|
|
||||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
|
||||||
LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto t_main_end = ggml_time_us();
|
const auto t_main_end = ggml_time_us();
|
||||||
|
|
||||||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user