mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
perplexity : faster Winogrande via batching (#5024)
* perplexity : faster Winogrande via batching ggml-ci * perplexity : remove unused function * perplexity : only tokenize selected tasks for Winogrande
This commit is contained in:
parent
57e2a7a52a
commit
8b20858e5e
@ -423,26 +423,31 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||||||
return {tokens, ppl, logit_history, prob_history};
|
return {tokens, ppl, logit_history, prob_history};
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int> & tokens,
|
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
|
||||||
int n_past, int n_batch, int n_vocab) {
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||||
std::vector<float> result;
|
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
result.reserve(tokens.size() * n_vocab);
|
|
||||||
size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
|
llama_batch batch_view = {
|
||||||
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
|
n_tokens,
|
||||||
size_t n_tokens = tokens.size() - i_chunk * n_batch;
|
batch.token + i,
|
||||||
n_tokens = std::min(n_tokens, size_t(n_batch));
|
nullptr,
|
||||||
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
|
batch.pos + i,
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
|
batch.n_seq_id + i,
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
batch.seq_id + i,
|
||||||
return {};
|
batch.logits + i,
|
||||||
|
0, 0, 0, // unused
|
||||||
|
};
|
||||||
|
|
||||||
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
|
if (ret != 0) {
|
||||||
|
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto logits = llama_get_logits(ctx);
|
memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
|
||||||
result.insert(result.end(), logits, logits + n_tokens * n_vocab);
|
|
||||||
|
|
||||||
n_past += n_tokens;
|
|
||||||
}
|
}
|
||||||
return result;
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
|
static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
|
||||||
@ -576,7 +581,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
|
|
||||||
// determine the common prefix of the endings
|
// determine the common prefix of the endings
|
||||||
hs_cur.common_prefix = 0;
|
hs_cur.common_prefix = 0;
|
||||||
hs_cur.required_tokens = 0;
|
|
||||||
for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
|
for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
|
||||||
if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
|
if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
|
||||||
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
|
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
|
||||||
@ -609,45 +613,18 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
const int n_ctx = llama_n_ctx(ctx);
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
const int n_batch = params.n_batch;
|
const int n_batch = params.n_batch;
|
||||||
|
|
||||||
const int max_tasks_per_batch = params.n_parallel;
|
const int max_tasks_per_batch = 32;
|
||||||
const int max_seq = 4*max_tasks_per_batch;
|
const int max_seq = 4*max_tasks_per_batch;
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
|
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
|
||||||
|
|
||||||
std::vector<float> tok_logits(n_vocab);
|
std::vector<float> tok_logits(n_vocab);
|
||||||
std::vector<float> batch_logits(n_ctx*n_vocab);
|
std::vector<float> batch_logits(n_vocab*n_ctx);
|
||||||
|
|
||||||
std::vector<std::pair<size_t, llama_token>> eval_pairs;
|
std::vector<std::pair<size_t, llama_token>> eval_pairs;
|
||||||
std::vector<float> eval_results;
|
std::vector<float> eval_results;
|
||||||
std::vector<std::thread> workers(std::thread::hardware_concurrency());
|
std::vector<std::thread> workers(std::thread::hardware_concurrency());
|
||||||
|
|
||||||
auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
|
||||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
|
||||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
|
||||||
|
|
||||||
llama_batch batch_view = {
|
|
||||||
n_tokens,
|
|
||||||
batch.token + i,
|
|
||||||
nullptr,
|
|
||||||
batch.pos + i,
|
|
||||||
batch.n_seq_id + i,
|
|
||||||
batch.seq_id + i,
|
|
||||||
batch.logits + i,
|
|
||||||
0, 0, 0, // unused
|
|
||||||
};
|
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
|
||||||
if (ret != 0) {
|
|
||||||
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
|
|
||||||
for (size_t i0 = 0; i0 < hs_task_count; i0++) {
|
for (size_t i0 = 0; i0 < hs_task_count; i0++) {
|
||||||
int n_cur = 0;
|
int n_cur = 0;
|
||||||
|
|
||||||
@ -696,7 +673,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
// decode all tasks [i0, i1)
|
// decode all tasks [i0, i1)
|
||||||
if (!decode_helper(ctx, batch, n_batch)) {
|
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
|
||||||
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
|
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -772,6 +749,13 @@ struct winogrande_entry {
|
|||||||
std::string second;
|
std::string second;
|
||||||
std::array<std::string, 2> choices;
|
std::array<std::string, 2> choices;
|
||||||
int answer;
|
int answer;
|
||||||
|
|
||||||
|
size_t i_batch;
|
||||||
|
size_t common_prefix;
|
||||||
|
size_t required_tokens;
|
||||||
|
size_t n_base1; // number of tokens for context + choice 1
|
||||||
|
size_t n_base2; // number of tokens for context + choice 2
|
||||||
|
std::vector<llama_token> seq_tokens[2];
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
|
static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
|
||||||
@ -875,91 +859,137 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
data = std::move(selected);
|
data = std::move(selected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
|
||||||
|
|
||||||
// This is needed as usual for LLaMA models
|
// This is needed as usual for LLaMA models
|
||||||
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
||||||
|
|
||||||
|
for (auto & task : data) {
|
||||||
|
task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
|
||||||
|
task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);
|
||||||
|
|
||||||
|
task.common_prefix = 0;
|
||||||
|
for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
|
||||||
|
if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
task.common_prefix++;
|
||||||
|
}
|
||||||
|
|
||||||
|
task.required_tokens = task.common_prefix +
|
||||||
|
task.seq_tokens[0].size() - task.common_prefix +
|
||||||
|
task.seq_tokens[1].size() - task.common_prefix;
|
||||||
|
|
||||||
|
task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
|
||||||
|
task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
|
||||||
|
}
|
||||||
|
|
||||||
fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
|
fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
|
||||||
|
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
const int n_ctx = llama_n_ctx(ctx);
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
|
const int n_batch = params.n_batch;
|
||||||
|
|
||||||
|
const int max_tasks_per_batch = 128;
|
||||||
|
const int max_seq = 2*max_tasks_per_batch;
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
|
||||||
|
|
||||||
std::vector<float> tok_logits(n_vocab);
|
std::vector<float> tok_logits(n_vocab);
|
||||||
|
std::vector<float> batch_logits(n_vocab*n_ctx);
|
||||||
|
|
||||||
int n_correct = 0;
|
int n_correct = 0;
|
||||||
int n_done = 0;
|
int n_done = 0;
|
||||||
|
|
||||||
for (size_t task_idx = 0; task_idx < data.size(); task_idx++) {
|
for (size_t i0 = 0; i0 < data.size(); i0++) {
|
||||||
const auto& task = data[task_idx];
|
int n_cur = 0;
|
||||||
|
|
||||||
auto base_context = ::llama_tokenize(ctx, task.first, add_bos);
|
size_t i1 = i0;
|
||||||
auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos);
|
size_t i_batch = 0;
|
||||||
auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos);
|
|
||||||
|
|
||||||
auto sentence_1st = task.first + task.choices[0] + task.second;
|
llama_batch_clear(batch);
|
||||||
auto sentence_2nd = task.first + task.choices[1] + task.second;
|
|
||||||
auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos);
|
|
||||||
auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos);
|
|
||||||
|
|
||||||
if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) {
|
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
|
||||||
fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size());
|
const int s0 = 2*(i1 - i0);
|
||||||
|
if (s0 + 2 > max_seq) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
|
||||||
|
llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
|
||||||
|
}
|
||||||
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
|
|
||||||
|
for (int s = 0; s < 2; ++s) {
|
||||||
|
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
|
||||||
|
llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data[i1].i_batch = i_batch;
|
||||||
|
i_batch += data[i1].required_tokens;
|
||||||
|
|
||||||
|
n_cur += data[i1].required_tokens;
|
||||||
|
if (++i1 == data.size()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i0 == i1) {
|
||||||
|
fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto query_1st_size = query_1st.size();
|
|
||||||
auto query_2nd_size = query_2nd.size();
|
|
||||||
|
|
||||||
// Speedup small evaluations by evaluating atleast 32 tokens
|
|
||||||
// For Winogrande this seems to slow it down rather than speed it up.
|
|
||||||
//if (query_1st.size() < 32) query_1st.resize(32);
|
|
||||||
//if (query_2nd.size() < 32) query_2nd.resize(32);
|
|
||||||
|
|
||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab);
|
|
||||||
|
|
||||||
llama_kv_cache_clear(ctx);
|
// decode all tasks [i0, i1)
|
||||||
auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab);
|
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
|
||||||
|
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
|
||||||
if (logits_1st.empty() || logits_2nd.empty()) {
|
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx &&
|
for (size_t i = i0; i < i1; ++i) {
|
||||||
query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx;
|
auto & task = data[i];
|
||||||
|
|
||||||
|
const bool skip_choice =
|
||||||
|
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
|
||||||
|
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
|
||||||
|
|
||||||
float score_1st = 0;
|
float score_1st = 0;
|
||||||
bool is_nan_1st = false;
|
bool is_nan_1st = false;
|
||||||
const auto& base_1 = skip_choice ? base_ctx_1st : base_context;
|
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
|
||||||
const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0;
|
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
|
||||||
for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) {
|
size_t li = n_base1 - 1;
|
||||||
std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float));
|
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
|
||||||
const float prob = softmax(tok_logits)[query_1st[j+1]];
|
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
|
||||||
|
const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]];
|
||||||
if (std::isnan(prob) || !prob) {
|
if (std::isnan(prob) || !prob) {
|
||||||
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
|
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
|
||||||
prob, j, sentence_1st.c_str(), base_context.size());
|
prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1);
|
||||||
is_nan_1st = true;
|
is_nan_1st = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
score_1st += std::log(prob);
|
score_1st += std::log(prob);
|
||||||
}
|
}
|
||||||
score_1st /= (query_1st_size - base_1.size() - last_1st);
|
score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
|
||||||
|
|
||||||
float score_2nd = 0;
|
float score_2nd = 0;
|
||||||
bool is_nan_2nd = false;
|
bool is_nan_2nd = false;
|
||||||
const auto& base_2 = skip_choice ? base_ctx_2nd : base_context;
|
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
|
||||||
const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0;
|
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
|
||||||
for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) {
|
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
|
||||||
std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float));
|
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
|
||||||
const float prob = softmax(tok_logits)[query_2nd[j+1]];
|
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
|
||||||
|
const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]];
|
||||||
if (std::isnan(prob) || !prob) {
|
if (std::isnan(prob) || !prob) {
|
||||||
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
|
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
|
||||||
prob, j, sentence_2nd.c_str(), base_context.size());
|
prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2);
|
||||||
is_nan_2nd = true;
|
is_nan_2nd = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
score_2nd += std::log(prob);
|
score_2nd += std::log(prob);
|
||||||
}
|
}
|
||||||
score_2nd /= (query_2nd_size - base_2.size() - last_2nd);
|
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
|
||||||
|
|
||||||
if (is_nan_1st || is_nan_2nd) {
|
if (is_nan_1st || is_nan_2nd) {
|
||||||
continue;
|
continue;
|
||||||
@ -967,10 +997,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
|
|
||||||
if (std::isnan(score_1st) || std::isnan(score_2nd)) {
|
if (std::isnan(score_1st) || std::isnan(score_2nd)) {
|
||||||
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
|
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
|
||||||
printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size);
|
printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size());
|
||||||
printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size);
|
printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size());
|
||||||
printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size());
|
printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix);
|
||||||
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice);
|
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -982,10 +1012,13 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
++n_done;
|
++n_done;
|
||||||
|
|
||||||
// Print the accumulated accuracy mean x 100
|
// Print the accumulated accuracy mean x 100
|
||||||
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer);
|
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
i0 = i1 - 1;
|
||||||
|
}
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
if (n_done < 100) return;
|
if (n_done < 100) return;
|
||||||
|
Loading…
Reference in New Issue
Block a user