mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-04 01:57:53 +01:00
examples : do not eval prompt 2 times (close #3348)
This commit is contained in:
parent
a207561503
commit
2b8830af71
@ -1,6 +1,7 @@
|
|||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -42,7 +43,9 @@ int main(int argc, char ** argv) {
|
|||||||
llama_context_params ctx_params = llama_context_default_params();
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
|
|
||||||
ctx_params.seed = 1234;
|
ctx_params.seed = 1234;
|
||||||
ctx_params.n_ctx = 2048;
|
ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
|
||||||
|
ctx_params.n_batch = std::max(n_len, n_parallel);
|
||||||
|
// ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
|
||||||
|
|
||||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
|
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
|
||||||
|
|
||||||
@ -66,11 +69,11 @@ int main(int argc, char ** argv) {
|
|||||||
const int n_ctx = llama_n_ctx(ctx);
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
|
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
|
||||||
|
|
||||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
|
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
|
||||||
|
|
||||||
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
||||||
if (n_kv_req > n_ctx) {
|
if (n_kv_req > n_ctx) {
|
||||||
LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
|
LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
|
||||||
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
|
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
|
|||||||
// create a llama_batch with size 512
|
// create a llama_batch with size 512
|
||||||
// we use this object to submit token data for decoding
|
// we use this object to submit token data for decoding
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(512, 0);
|
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
|
||||||
|
|
||||||
// evaluate the initial prompt
|
// evaluate the initial prompt
|
||||||
batch.n_tokens = tokens_list.size();
|
batch.n_tokens = tokens_list.size();
|
||||||
@ -133,12 +136,6 @@ int main(int argc, char ** argv) {
|
|||||||
const auto t_main_start = ggml_time_us();
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
while (n_cur <= n_len) {
|
while (n_cur <= n_len) {
|
||||||
// evaluate the current batch with the transformer model
|
|
||||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
|
||||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare the next batch
|
// prepare the next batch
|
||||||
batch.n_tokens = 0;
|
batch.n_tokens = 0;
|
||||||
|
|
||||||
@ -149,8 +146,8 @@ int main(int argc, char ** argv) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto n_vocab = llama_n_vocab(ctx);
|
auto n_vocab = llama_n_vocab(ctx);
|
||||||
auto logits = llama_get_logits_ith(ctx, i_batch[i]);
|
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
@ -178,7 +175,7 @@ int main(int argc, char ** argv) {
|
|||||||
i_batch[i] = -1;
|
i_batch[i] = -1;
|
||||||
LOG_TEE("\n");
|
LOG_TEE("\n");
|
||||||
if (n_parallel > 1) {
|
if (n_parallel > 1) {
|
||||||
LOG_TEE("%s: stream %d finished", __func__, i);
|
LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
@ -211,6 +208,12 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
n_cur += 1;
|
n_cur += 1;
|
||||||
|
|
||||||
|
// evaluate the current batch with the transformer model
|
||||||
|
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("\n");
|
LOG_TEE("\n");
|
||||||
|
@ -110,16 +110,10 @@ int main(int argc, char ** argv) {
|
|||||||
const auto t_main_start = ggml_time_us();
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
while (n_cur <= n_len) {
|
while (n_cur <= n_len) {
|
||||||
// evaluate the current batch with the transformer model
|
|
||||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
|
||||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// sample the next token
|
// sample the next token
|
||||||
{
|
{
|
||||||
auto n_vocab = llama_n_vocab(ctx);
|
auto n_vocab = llama_n_vocab(ctx);
|
||||||
auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
@ -158,6 +152,12 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
n_cur += 1;
|
n_cur += 1;
|
||||||
|
|
||||||
|
// evaluate the current batch with the transformer model
|
||||||
|
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("\n");
|
LOG_TEE("\n");
|
||||||
|
Loading…
Reference in New Issue
Block a user