Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-01-12 05:17:21 +01:00
grammars: fix resampling logic regression (#7424)
This commit is contained in:
parent fcf6538ba6
commit e402de364b
@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl(
|
|||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
const int idx,
|
const int idx,
|
||||||
bool is_resampling) { // Add a parameter to indicate if we are resampling
|
bool is_resampling) {
|
||||||
const llama_sampling_params & params = ctx_sampling->params;
|
const llama_sampling_params & params = ctx_sampling->params;
|
||||||
|
|
||||||
const float temp = params.temp;
|
const float temp = params.temp;
|
||||||
@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl(
|
|||||||
const float mirostat_eta = params.mirostat_eta;
|
const float mirostat_eta = params.mirostat_eta;
|
||||||
|
|
||||||
std::vector<float> original_logits;
|
std::vector<float> original_logits;
|
||||||
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
|
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
|
||||||
if (!is_resampling) {
|
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
||||||
GGML_ASSERT(!original_logits.empty());
|
GGML_ASSERT(!original_logits.empty());
|
||||||
}
|
}
|
||||||
llama_token id = 0;
|
llama_token id = 0;
|
||||||
@ -252,7 +252,7 @@ static llama_token llama_sampling_sample_impl(
|
|||||||
// Restore logits from the copy
|
// Restore logits from the copy
|
||||||
std::copy(original_logits.begin(), original_logits.end(), logits);
|
std::copy(original_logits.begin(), original_logits.end(), logits);
|
||||||
|
|
||||||
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
|
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -285,7 +285,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
|
||||||
if (apply_grammar && original_logits != NULL) {
|
if (ctx_sampling->grammar != NULL && !apply_grammar) {
|
||||||
|
GGML_ASSERT(original_logits != NULL);
|
||||||
// Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
|
// Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
|
||||||
*original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
|
*original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
|
||||||
}
|
}
|
||||||
@ -342,7 +343,7 @@ llama_token llama_sampling_sample(
|
|||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
const int idx) {
|
const int idx) {
|
||||||
// Call the implementation function with is_resampling set to false by default
|
// Call the implementation function with is_resampling set to false by default
|
||||||
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array llama_sampling_prepare(
|
llama_token_data_array llama_sampling_prepare(
|
||||||
|
@ -707,7 +707,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
|
||||||
|
|
||||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
||||||
|
|
||||||
@ -728,7 +728,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// push the prompt in the sampling context in order to apply repetition penalties later
|
// push the prompt in the sampling context in order to apply repetition penalties later
|
||||||
// for the prompt, we don't apply grammar rules
|
// for the prompt, we don't apply grammar rules
|
||||||
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
|
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
|
||||||
|
|
||||||
++n_consumed;
|
++n_consumed;
|
||||||
if ((int) embd.size() >= params.n_batch) {
|
if ((int) embd.size() >= params.n_batch) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user