examples : do not eval prompt 2 times (close #3348)

This commit is contained in:
Georgi Gerganov 2023-09-28 17:48:25 +03:00
parent a207561503
commit 2b8830af71
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 24 additions and 21 deletions

View File

@ -1,6 +1,7 @@
#include "common.h"
#include "llama.h"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <string>
@ -42,7 +43,9 @@ int main(int argc, char ** argv) {
llama_context_params ctx_params = llama_context_default_params();
ctx_params.seed = 1234;
ctx_params.n_ctx = 2048;
ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
ctx_params.n_batch = std::max(n_len, n_parallel);
// ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
@ -66,11 +69,11 @@ int main(int argc, char ** argv) {
const int n_ctx = llama_n_ctx(ctx);
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
// make sure the KV cache is big enough to hold all the prompt and generated tokens
if (n_kv_req > n_ctx) {
LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
return 1;
}
@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
// create a llama_batch with size 512
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(512, 0);
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
// evaluate the initial prompt
batch.n_tokens = tokens_list.size();
@ -133,12 +136,6 @@ int main(int argc, char ** argv) {
const auto t_main_start = ggml_time_us();
while (n_cur <= n_len) {
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
// prepare the next batch
batch.n_tokens = 0;
@ -149,8 +146,8 @@ int main(int argc, char ** argv) {
continue;
}
auto n_vocab = llama_n_vocab(ctx);
auto logits = llama_get_logits_ith(ctx, i_batch[i]);
auto n_vocab = llama_n_vocab(ctx);
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -178,7 +175,7 @@ int main(int argc, char ** argv) {
i_batch[i] = -1;
LOG_TEE("\n");
if (n_parallel > 1) {
LOG_TEE("%s: stream %d finished", __func__, i);
LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
}
continue;
@ -211,6 +208,12 @@ int main(int argc, char ** argv) {
}
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
}
LOG_TEE("\n");

View File

@ -110,16 +110,10 @@ int main(int argc, char ** argv) {
const auto t_main_start = ggml_time_us();
while (n_cur <= n_len) {
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
// sample the next token
{
auto n_vocab = llama_n_vocab(ctx);
auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
auto n_vocab = llama_n_vocab(ctx);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -158,6 +152,12 @@ int main(int argc, char ** argv) {
}
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
}
LOG_TEE("\n");