mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 06:19:02 +01:00
llama : refactor samplers internal implementation (#9370)
This commit is contained in:
parent
2a358fb0c4
commit
19f4a7b296
@ -101,6 +101,10 @@ struct ring_buffer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void push_back(const T & value) {
|
void push_back(const T & value) {
|
||||||
|
if (capacity == 0) {
|
||||||
|
throw std::runtime_error("ring buffer: capacity is zero");
|
||||||
|
}
|
||||||
|
|
||||||
if (sz == capacity) {
|
if (sz == capacity) {
|
||||||
// advance the start when buffer is full
|
// advance the start when buffer is full
|
||||||
first = (first + 1) % capacity;
|
first = (first + 1) % capacity;
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -23,16 +23,6 @@ struct llama_sampler_chain {
|
|||||||
mutable int32_t n_sample;
|
mutable int32_t n_sample;
|
||||||
};
|
};
|
||||||
|
|
||||||
using llama_token_cnt = std::unordered_map<llama_token, int>;
|
|
||||||
|
|
||||||
// TODO: tmp exposed until test-sampling is fixed
|
|
||||||
void llama_sampler_penalties_impl(
|
|
||||||
llama_token_data_array * cur_p,
|
|
||||||
const llama_token_cnt & token_count,
|
|
||||||
float penalty_repeat,
|
|
||||||
float penalty_freq,
|
|
||||||
float penalty_present);
|
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_grammar_impl(
|
struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||||
const struct llama_vocab & vocab,
|
const struct llama_vocab & vocab,
|
||||||
const char * grammar_str,
|
const char * grammar_str,
|
||||||
|
@ -148,15 +148,17 @@ static void test_penalties(
|
|||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_cnt token_count;
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
|
|
||||||
|
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
|
||||||
|
|
||||||
for (size_t i = 0; i < last_tokens.size(); i++) {
|
for (size_t i = 0; i < last_tokens.size(); i++) {
|
||||||
token_count[last_tokens[i]]++;
|
llama_sampler_accept(sampler, last_tokens[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
APPLY(llama_sampler_init_softmax(), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
llama_sampler_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
|
APPLY(sampler, &cur_p);
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
APPLY(llama_sampler_init_softmax(), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user