sampling : deduplicated code for probability distribution access (#6240)

* sampling: remove duplicated code for probability distribution access

* free original_logits

* fix original_logits allocation

* fixes based on review @cebtenzzre

* change function name to `llama_sampling_prepare`
Minsoo Cheong authored 2024-03-24 17:54:07 +09:00 (committed by GitHub)
parent ddf6568510
commit 586e7bc561
4 changed files with 28 additions and 76 deletions
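
In short, the commit folds the old `llama_sampling_probability_distribution()` path and the duplicated body of `llama_sampling_sample_impl()` into one preparation routine. A minimal sketch of the consolidated call pattern, assuming an initialized `ctx_sampling`, `ctx_main`, and batch index `idx` (setup omitted; the names come from the diff below):

    // prepare applies logit biases, guidance (when ctx_cfg is set), repetition
    // penalties and, optionally, grammar constraints, and can hand back a copy
    // of the raw logits through the out-parameter
    std::vector<float> original_logits;
    llama_token_data_array cur_p = llama_sampling_prepare(
            ctx_sampling, ctx_main, /*ctx_cfg=*/ NULL, idx,
            /*apply_grammar=*/ true, &original_logits);

    // unlike the removed function, prepare no longer ends with a softmax, so
    // callers that need normalized probabilities apply it themselves:
    llama_sample_softmax(ctx_main, &cur_p);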

common/sampling.cpp

@@ -168,77 +168,20 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
     const float   temp            = params.temp;
-    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float   penalty_repeat  = params.penalty_repeat;
-    const float   penalty_freq    = params.penalty_freq;
-    const float   penalty_present = params.penalty_present;
     const int     mirostat        = params.mirostat;
     const float   mirostat_tau    = params.mirostat_tau;
     const float   mirostat_eta    = params.mirostat_eta;
-    const bool    penalize_nl     = params.penalize_nl;
 
-    auto & prev = ctx_sampling->prev;
-    auto & cur  = ctx_sampling->cur;
-
+    std::vector<float> original_logits;
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
+    if (!is_resampling) {
+        GGML_ASSERT(!original_logits.empty());
+    }
     llama_token id = 0;
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Declare original_logits at the beginning of the function scope
-    std::vector<float> original_logits;
-
-    if (!is_resampling) {
-        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
-        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
-    }
-
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
-    }
-
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
-    }
-
-    cur.clear();
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
-    // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    // If we are in the resampling phase, apply grammar checks before sampling logic
-    if (is_resampling && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
-    }
-
     if (temp < 0.0) {
         // greedy sampling, with probs
         llama_sample_softmax(ctx_main, &cur_p);
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
-static llama_token_data_array llama_sample_probability_distribution_impl(
+static llama_token_data_array llama_sampling_prepare_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
     const float   penalty_present = params.penalty_present;
+
     const bool    penalize_nl     = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Declare original_logits at the beginning of the function scope
-    std::vector<float> original_logits;
+    if (apply_grammar && original_logits != NULL) {
+        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
+        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+    }
 
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
         }
     }
 
-    // apply grammar checks
-    if (ctx_sampling->grammar != NULL) {
+    // apply grammar checks before sampling logic
+    if (apply_grammar && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
-    llama_sample_softmax(ctx_main, &cur_p);
-
     return cur_p;
 }
@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
-llama_token_data_array llama_sampling_probability_distribution(
+llama_token_data_array llama_sampling_prepare(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
-    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
+    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
 }
 
 void llama_sampling_accept(
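
A side note on the two fixup entries in the commit message ("free original_logits", "fix original_logits allocation"): the review settled on a `std::vector<float> *` out-parameter, so the copy of the logits is owned by the caller and released by RAII rather than freed by hand. A self-contained sketch of that assignment pattern, with stand-in data instead of the llama.cpp API:

    #include <vector>

    int main() {
        float logits[4] = {0.1f, 0.2f, 0.3f, 0.4f};   // stand-in for llama_get_logits_ith() output
        const int n_vocab = 4;                        // stand-in vocabulary size

        std::vector<float> original_logits;           // caller-owned, as in llama_sampling_sample_impl
        // the same range assignment prepare_impl performs through its pointer parameter:
        original_logits = {logits, logits + n_vocab};

        // no explicit free needed: the vector releases the copy when it goes out of scope
        return 0;
    }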

common/sampling.h

@@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
     struct llama_context * ctx_cfg,
     int idx = 0);
 
-// returns the probability that token of given id will be sampled
-llama_token_data_array llama_sampling_probability_distribution(
+// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
+llama_token_data_array llama_sampling_prepare(
     struct llama_sampling_context * ctx_sampling,
     struct llama_context * ctx_main,
     struct llama_context * ctx_cfg,
-    int idx = 0);
+    int idx = 0,
+    bool apply_grammar = true,
+    std::vector<float> * original_logits = nullptr);
 
 void llama_sampling_accept(
     struct llama_sampling_context * ctx_sampling,
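
Because `apply_grammar` and `original_logits` are defaulted, existing four-argument call sites keep compiling unchanged; a minimal sketch, assuming a configured sampling context:

    // equivalent to llama_sampling_prepare(ctx_sampling, ctx_main, nullptr, idx, true, nullptr):
    // grammar is applied and no copy of the raw logits is requested
    llama_token_data_array cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, nullptr, idx);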

examples/speculative/speculative.cpp

@@ -219,7 +219,8 @@ int main(int argc, char ** argv) {
         if (params.sparams.temp > 0) {
             // stochastic verification
 
-            llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+            llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
+            llama_sample_softmax(ctx_tgt, &dist_tgt);
             float p_tgt = 0, p_dft = 0;
 
             // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
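
Since the softmax now happens at the call site, the `.p` fields of `dist_tgt` are meaningful only after the added `llama_sample_softmax()` line. An illustrative helper (hypothetical, not part of the commit) showing how the verification step can read the target probability of a drafted token:

    // prob_of() is a hypothetical helper; dist must already be softmax-normalized
    static float prob_of(const llama_token_data_array & dist, llama_token drafted_id) {
        for (size_t i = 0; i < dist.size; ++i) {
            if (dist.data[i].id == drafted_id) {
                return dist.data[i].p;
            }
        }
        return 0.0f;
    }

The draft token is then accepted with probability min(1, p_tgt / p_dft), the standard speculative-sampling criterion.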

retrieval (executable, binary file not shown)