#ifdef LLAMA_USE_LLGUIDANCE

#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"

#include "llguidance.h"

#include <cmath>  // INFINITY
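
// Sampler state: the llguidance constraint plus everything needed to
// rebuild it (the grammar text) and to talk to the model's vocabulary.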
struct llama_sampler_llg {
    const llama_model * model;
    const llama_vocab * vocab;
    std::string         grammar_kind;
    std::string         grammar_data;
    LlgTokenizer *      tokenizer;
    LlgConstraint *     grammar;
    LlgMaskResult       llg_res;      // cached mask from llg_compute_mask()
    bool                has_llg_res;  // whether llg_res is valid for the current position
};
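
// Build an llguidance constraint from (grammar_kind, grammar_data);
// returns nullptr (after logging) if the grammar fails to compile.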
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
                                             const char * grammar_data) {
    LlgConstraintInit cinit;
    llg_constraint_init_set_defaults(&cinit, tokenizer);
    auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
    if (llg_get_error(c)) {
        LOG_ERR("llg error: %s\n", llg_get_error(c));
        llg_free_constraint(c);
        return nullptr;
    }
    return c;
}

static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
    return "llguidance";
}
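
// Commit the sampled token to the constraint and invalidate the cached mask.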
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        LlgCommitResult res;
        llg_commit_token(ctx->grammar, token, &res);
        ctx->has_llg_res = false;
    }
}
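
// Apply the constraint to the candidate list: compute (or reuse) the token
// mask for the current position and set disallowed logits to -INFINITY.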
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        if (!ctx->has_llg_res) {
            if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
                ctx->has_llg_res = true;
            } else {
                LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
                llg_free_constraint(ctx->grammar);
                ctx->grammar = nullptr;
            }
        }
        if (ctx->has_llg_res) {
            if (ctx->llg_res.is_stop) {
                // grammar is complete - only EOG tokens remain allowed
                for (size_t i = 0; i < cur_p->size; ++i) {
                    if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            } else {
                const uint32_t * mask = ctx->llg_res.sample_mask;
                for (size_t i = 0; i < cur_p->size; ++i) {
                    auto token = cur_p->data[i].id;
                    // one bit per token; 1U avoids signed-shift UB when token % 32 == 31
                    if ((mask[token / 32] & (1U << (token % 32))) == 0) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            }
        }
    }
}
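
// Reset: recompile the constraint from the stored grammar text.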
static void llama_sampler_llg_reset(llama_sampler * smpl) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (!ctx->grammar) {
        return;
    }

    auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
    llg_free_constraint(ctx->grammar);
    ctx->grammar = grammar_new;
    ctx->has_llg_res = false;
}
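
// Deep-copy the sampler: clone both the constraint and the tokenizer handle.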
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;

    auto * result = llama_sampler_init_llg(ctx->model, nullptr, nullptr);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_llg *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_kind = ctx->grammar_kind;
            result_ctx->grammar_data = ctx->grammar_data;
            result_ctx->grammar      = llg_clone_constraint(ctx->grammar);
            result_ctx->tokenizer    = llg_clone_tokenizer(ctx->tokenizer);
        }
    }

    return result;
}
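
// Free the constraint and tokenizer (if any), then the sampler state itself.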
static void llama_sampler_llg_free(llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_llg *) smpl->ctx;

    if (ctx->grammar) {
        llg_free_constraint(ctx->grammar);
        llg_free_tokenizer(ctx->tokenizer);
    }

    delete ctx;
}
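
// vtable wiring for the llama_sampler interface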
static llama_sampler_i llama_sampler_llg_i = {
    /* .name   = */ llama_sampler_llg_name,
    /* .accept = */ llama_sampler_llg_accept_impl,
    /* .apply  = */ llama_sampler_llg_apply,
    /* .reset  = */ llama_sampler_llg_reset,
    /* .clone  = */ llama_sampler_llg_clone,
    /* .free   = */ llama_sampler_llg_free,
};
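
// Tokenize callback handed to llguidance; follows llama_tokenize()
// conventions: a negative result encodes the required buffer size.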
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
                                            uint32_t * output_tokens, size_t output_tokens_len) {
    const llama_vocab * vocab = (const llama_vocab *) user_data;
    int r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len,
                           false, true);
    if (r < 0) {
        // llama_tokenize() returns the negated token count when the output
        // buffer is too small; llguidance expects the positive count back
        return -r;
    }
    return r;
}
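
// Build an LlgTokenizer that mirrors the model's vocabulary. The flattened
// (token_lens, token_bytes) tables give llguidance the byte representation
// of every token; special tokens are marked with a 0xff prefix byte.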
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_model * model) {
    // TODO store the tokenizer in the model somehow
    static const llama_model * model_cache;
    static LlgTokenizer *      tokenizer_cache;

    if (model_cache == model) {
        return llg_clone_tokenizer(tokenizer_cache);
    }

    const llama_vocab * vocab = llama_model_get_vocab(model);

    auto tok_eos = llama_vocab_eot(vocab);
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);
    }

    size_t vocab_size = llama_vocab_n_tokens(vocab);

    auto token_lens = new uint32_t[vocab_size];
    // we typically have ~7 bytes per token; let's go on the safe side here
    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
    auto token_bytes      = new uint8_t[token_bytes_size];

    size_t offset = 0;
    for (size_t i = 0; i < vocab_size; i++) {
        size_t max_token = 1024;
        if (token_bytes_size - offset < max_token) {
            GGML_ABORT("token_bytes buffer too small\n");
        }

        llama_token token = i;
        auto dp   = (char *) token_bytes + offset;
        auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
        if (size < 0) {
            GGML_ABORT("llama_detokenize failed\n");
        }
        if (size == 0) {
            // detokenizing with specials disabled yielded nothing, so this is
            // a special token: retry with specials enabled and mark it
            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
            if (size < 0) {
                GGML_ABORT("llama_detokenize failed\n");
            }
            if (size != 0) {
                *dp = '\xff';  // special token prefix marker
                size += 1;
            }
        }

        token_lens[i] = size;
        offset += size;
    }

    LlgTokenizerInit tinit = {
        /* .vocab_size                         = */ (uint32_t) vocab_size,
        /* .tok_eos                            = */ (uint32_t) tok_eos,
        /* .token_lens                         = */ token_lens,
        /* .token_bytes                        = */ token_bytes,
        /* .tokenizer_json                     = */ nullptr,
        /* .tokenize_assumes_string            = */ false,
        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
        /* .use_approximate_greedy_tokenize_fn = */ false,
        /* .tokenize_user_data                 = */ vocab,
    };

    char error_buffer[1024];
    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

    delete[] token_bytes;
    delete[] token_lens;

    if (tokenizer == nullptr) {
        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
        return tokenizer;
    }

    if (tokenizer_cache) {
        llg_free_tokenizer(tokenizer_cache);
    }
    model_cache     = model;
    tokenizer_cache = tokenizer;

    return tokenizer;
}
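
// Public constructor. With a non-empty grammar_kind the sampler enforces the
// grammar; with an empty one it becomes a no-op (as used by clone()).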
llama_sampler * llama_sampler_init_llg(const llama_model * model, const char * grammar_kind,
                                       const char * grammar_data) {
    auto * ctx = new llama_sampler_llg;

    const llama_vocab * vocab = llama_model_get_vocab(model);

    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
        auto tokenizer = llama_sampler_llg_new_tokenizer(model);
        *ctx = {
            /* .model        = */ model,
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ grammar_kind,
            /* .grammar_data = */ grammar_data,
            /* .tokenizer    = */ tokenizer,
            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    } else {
        *ctx = {
            /* .model        = */ model,
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ {},
            /* .grammar_data = */ {},
            /* .tokenizer    = */ nullptr,
            /* .grammar      = */ nullptr,
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    }

    return new llama_sampler {
        /* .iface = */ &llama_sampler_llg_i,
        /* .ctx   = */ ctx,
    };
}
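
// Usage sketch (illustrative, not part of this file): attach the sampler to a
// llama.cpp sampler chain. The grammar_kind strings (e.g. "regex",
// "json_schema", "lark") are defined by llguidance; check its docs for the
// exact set supported by llg_new_constraint_any().
//
//   llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
//   llama_sampler_chain_add(chain, llama_sampler_init_llg(model, "json_schema", schema_json));
//   llama_sampler_chain_add(chain, llama_sampler_init_greedy());
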
#endif // LLAMA_USE_LLGUIDANCE