mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 22:38:58 +01:00
try a different fix
This commit is contained in:
parent
e15c61635f
commit
32a392fe68
@ -458,23 +458,24 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define K_TOKEN_CHUNK 4
|
||||||
|
|
||||||
static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
|
static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
|
||||||
const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
|
const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
|
||||||
constexpr int k_token_chunk = 4;
|
|
||||||
if (eval_results.size() != eval_pairs.size()) {
|
if (eval_results.size() != eval_pairs.size()) {
|
||||||
eval_results.resize(eval_pairs.size());
|
eval_results.resize(eval_pairs.size());
|
||||||
}
|
}
|
||||||
if (eval_pairs.empty()) return;
|
if (eval_pairs.empty()) return;
|
||||||
|
|
||||||
size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size());
|
size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());
|
||||||
|
|
||||||
std::atomic<int> counter(0);
|
std::atomic<int> counter(0);
|
||||||
auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab, k_token_chunk] () {
|
auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
|
||||||
float local_logprobs[k_token_chunk];
|
float local_logprobs[K_TOKEN_CHUNK];
|
||||||
while (true) {
|
while (true) {
|
||||||
size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed);
|
size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
|
||||||
if (first >= eval_results.size()) break;
|
if (first >= eval_results.size()) break;
|
||||||
size_t last = std::min(first + k_token_chunk, eval_results.size());
|
size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
|
||||||
for (size_t i = first; i < last; ++i) {
|
for (size_t i = first; i < last; ++i) {
|
||||||
auto logits = batch_logits + eval_pairs[i].first * n_vocab;
|
auto logits = batch_logits + eval_pairs[i].first * n_vocab;
|
||||||
float max_logit = logits[0];
|
float max_logit = logits[0];
|
||||||
@ -497,7 +498,6 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
|
|||||||
for (size_t it = 0; it < max_threads; ++it) {
|
for (size_t it = 0; it < max_threads; ++it) {
|
||||||
workers[it].join();
|
workers[it].join();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
|
Loading…
Reference in New Issue
Block a user