mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 06:10:29 +01:00
sampling : deduplicated code for probability distribution access (#6240)
* sampling: remove duplicated code for probability distribution access * free original_logits * fix original_logits allocation * fixes based on review @cebtenzzre * change function name to `llama_sampling_prepare`
This commit is contained in:
parent
ddf6568510
commit
586e7bc561
@ -168,77 +168,20 @@ static llama_token llama_sampling_sample_impl(
|
|||||||
bool is_resampling) { // Add a parameter to indicate if we are resampling
|
bool is_resampling) { // Add a parameter to indicate if we are resampling
|
||||||
const llama_sampling_params & params = ctx_sampling->params;
|
const llama_sampling_params & params = ctx_sampling->params;
|
||||||
|
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
|
||||||
|
|
||||||
const float temp = params.temp;
|
const float temp = params.temp;
|
||||||
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
|
||||||
const float penalty_repeat = params.penalty_repeat;
|
|
||||||
const float penalty_freq = params.penalty_freq;
|
|
||||||
const float penalty_present = params.penalty_present;
|
|
||||||
const int mirostat = params.mirostat;
|
const int mirostat = params.mirostat;
|
||||||
const float mirostat_tau = params.mirostat_tau;
|
const float mirostat_tau = params.mirostat_tau;
|
||||||
const float mirostat_eta = params.mirostat_eta;
|
const float mirostat_eta = params.mirostat_eta;
|
||||||
const bool penalize_nl = params.penalize_nl;
|
|
||||||
|
|
||||||
auto & prev = ctx_sampling->prev;
|
|
||||||
auto & cur = ctx_sampling->cur;
|
|
||||||
|
|
||||||
|
std::vector<float> original_logits;
|
||||||
|
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
|
||||||
|
if (!is_resampling) {
|
||||||
|
GGML_ASSERT(!original_logits.empty());
|
||||||
|
}
|
||||||
llama_token id = 0;
|
llama_token id = 0;
|
||||||
|
|
||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
|
||||||
// Declare original_logits at the beginning of the function scope
|
|
||||||
std::vector<float> original_logits;
|
|
||||||
|
|
||||||
if (!is_resampling) {
|
|
||||||
// Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
|
|
||||||
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// apply params.logit_bias map
|
|
||||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
|
||||||
logits[it->first] += it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ctx_cfg) {
|
|
||||||
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
|
||||||
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
|
||||||
}
|
|
||||||
|
|
||||||
cur.clear();
|
|
||||||
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
||||||
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
|
||||||
|
|
||||||
// apply penalties
|
|
||||||
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
|
||||||
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
|
||||||
if (penalty_tokens_used_size) {
|
|
||||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
|
||||||
|
|
||||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
|
||||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
|
||||||
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
|
||||||
|
|
||||||
if (!penalize_nl) {
|
|
||||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
|
||||||
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
|
||||||
cur_p.data[idx].logit = nl_logit;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we are in the resampling phase, apply grammar checks before sampling logic
|
|
||||||
if (is_resampling && ctx_sampling->grammar != NULL) {
|
|
||||||
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (temp < 0.0) {
|
if (temp < 0.0) {
|
||||||
// greedy sampling, with probs
|
// greedy sampling, with probs
|
||||||
llama_sample_softmax(ctx_main, &cur_p);
|
llama_sample_softmax(ctx_main, &cur_p);
|
||||||
@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
|
|||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
static llama_token_data_array llama_sample_probability_distribution_impl(
|
static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
const int idx) {
|
const int idx,
|
||||||
|
bool apply_grammar,
|
||||||
|
std::vector<float> * original_logits) {
|
||||||
const llama_sampling_params & params = ctx_sampling->params;
|
const llama_sampling_params & params = ctx_sampling->params;
|
||||||
|
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
|
|||||||
const float penalty_repeat = params.penalty_repeat;
|
const float penalty_repeat = params.penalty_repeat;
|
||||||
const float penalty_freq = params.penalty_freq;
|
const float penalty_freq = params.penalty_freq;
|
||||||
const float penalty_present = params.penalty_present;
|
const float penalty_present = params.penalty_present;
|
||||||
|
|
||||||
const bool penalize_nl = params.penalize_nl;
|
const bool penalize_nl = params.penalize_nl;
|
||||||
|
|
||||||
auto & prev = ctx_sampling->prev;
|
auto & prev = ctx_sampling->prev;
|
||||||
@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
|
|||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
|
||||||
// Declare original_logits at the beginning of the function scope
|
if (apply_grammar && original_logits != NULL) {
|
||||||
std::vector<float> original_logits;
|
// Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
|
||||||
|
*original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
|
||||||
|
}
|
||||||
|
|
||||||
// apply params.logit_bias map
|
// apply params.logit_bias map
|
||||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||||
@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply grammar checks
|
// apply grammar checks before sampling logic
|
||||||
if (ctx_sampling->grammar != NULL) {
|
if (apply_grammar && ctx_sampling->grammar != NULL) {
|
||||||
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sample_softmax(ctx_main, &cur_p);
|
|
||||||
return cur_p;
|
return cur_p;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
|
|||||||
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array llama_sampling_probability_distribution(
|
llama_token_data_array llama_sampling_prepare(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
const int idx) {
|
const int idx,
|
||||||
return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
|
bool apply_grammar,
|
||||||
|
std::vector<float> * original_logits) {
|
||||||
|
return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
|
@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
|
|||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
int idx = 0);
|
int idx = 0);
|
||||||
|
|
||||||
// returns the probability that token of given id will be sampled
|
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
|
||||||
llama_token_data_array llama_sampling_probability_distribution(
|
llama_token_data_array llama_sampling_prepare(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
int idx = 0);
|
int idx = 0,
|
||||||
|
bool apply_grammar = true,
|
||||||
|
std::vector<float> * original_logits = nullptr);
|
||||||
|
|
||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
@ -219,7 +219,8 @@ int main(int argc, char ** argv) {
|
|||||||
if (params.sparams.temp > 0) {
|
if (params.sparams.temp > 0) {
|
||||||
// stochastic verification
|
// stochastic verification
|
||||||
|
|
||||||
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
|
||||||
|
llama_sample_softmax(ctx_tgt, &dist_tgt);
|
||||||
float p_tgt = 0, p_dft = 0;
|
float p_tgt = 0, p_dft = 0;
|
||||||
|
|
||||||
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
||||||
|
Loading…
Reference in New Issue
Block a user