mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
examples : do not eval prompt 2 times (close #3348)
This commit is contained in:
parent
a207561503
commit
2b8830af71
@ -1,6 +1,7 @@
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
@ -42,7 +43,9 @@ int main(int argc, char ** argv) {
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
|
||||
ctx_params.seed = 1234;
|
||||
ctx_params.n_ctx = 2048;
|
||||
ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
|
||||
ctx_params.n_batch = std::max(n_len, n_parallel);
|
||||
// ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
|
||||
|
||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
|
||||
|
||||
@ -66,11 +69,11 @@ int main(int argc, char ** argv) {
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
|
||||
|
||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
|
||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
|
||||
|
||||
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
||||
if (n_kv_req > n_ctx) {
|
||||
LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
|
||||
LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
|
||||
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
|
||||
// create a llama_batch with size 512
|
||||
// we use this object to submit token data for decoding
|
||||
|
||||
llama_batch batch = llama_batch_init(512, 0);
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = tokens_list.size();
|
||||
@ -133,12 +136,6 @@ int main(int argc, char ** argv) {
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
while (n_cur <= n_len) {
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// prepare the next batch
|
||||
batch.n_tokens = 0;
|
||||
|
||||
@ -149,8 +146,8 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
@ -178,7 +175,7 @@ int main(int argc, char ** argv) {
|
||||
i_batch[i] = -1;
|
||||
LOG_TEE("\n");
|
||||
if (n_parallel > 1) {
|
||||
LOG_TEE("%s: stream %d finished", __func__, i);
|
||||
LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
|
||||
}
|
||||
|
||||
continue;
|
||||
@ -211,6 +208,12 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
n_cur += 1;
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
@ -110,16 +110,10 @@ int main(int argc, char ** argv) {
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
while (n_cur <= n_len) {
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// sample the next token
|
||||
{
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
@ -158,6 +152,12 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
n_cur += 1;
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
Loading…
Reference in New Issue
Block a user