llama : remove token functions with context args in favor of model (#3720)

* added `llama_model_token_*` variants to all the `llama_token_*` functions.

* added `LLAMA_API`

* formatting

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* removed old `llama_token` functions

* changed 3 more functions to take in model

- `llama_token_get_text`
- `llama_token_get_score`
- `llama_token_get_type`

* added back docs

* fixed main.cpp

* changed token functions to use new model variants

* changed token functions to use new model variants

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Marcus Dunn 2023-10-23 12:40:03 -07:00 committed by GitHub
parent 6336701c93
commit 5be6c803fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 81 additions and 79 deletions

View File

@ -880,13 +880,13 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
} }
if (params.ignore_eos) { if (params.ignore_eos) {
params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY; params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
} }
{ {
LOG("warming up the model with an empty run\n"); LOG("warming up the model with an empty run\n");
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_tokens_rm(lctx, -1, -1); llama_kv_cache_tokens_rm(lctx, -1, -1);
llama_reset_timings(lctx); llama_reset_timings(lctx);
@ -941,7 +941,7 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
} }
std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) { std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
const llama_token bos_id = llama_token_bos(ctx); const llama_token bos_id = llama_token_bos(llama_get_model(ctx));
std::string piece; std::string piece;
std::string result; std::string result;
@ -1186,7 +1186,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx)); const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");

View File

@ -147,7 +147,7 @@ llama_token llama_sampling_sample(
// apply penalties // apply penalties
if (!prev.empty()) { if (!prev.empty()) {
const float nl_logit = logits[llama_token_nl(ctx_main)]; const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
llama_sample_repetition_penalties(ctx_main, &cur_p, llama_sample_repetition_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - penalty_last_n, prev.data() + prev.size() - penalty_last_n,
@ -155,7 +155,7 @@ llama_token llama_sampling_sample(
if (!penalize_nl) { if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) { for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(ctx_main)) { if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit; cur_p.data[idx].logit = nl_logit;
break; break;
} }

View File

@ -236,8 +236,8 @@ int64_t get_example_targets_batch(
int64_t used_samples = 0; int64_t used_samples = 0;
ggml_set_f32(target_probs, 0.0f); ggml_set_f32(target_probs, 0.0f);
llama_token bos = llama_token_bos(lctx); llama_token bos = llama_token_bos(llama_get_model(lctx));
llama_token eos = llama_token_eos(lctx); llama_token eos = llama_token_eos(llama_get_model(lctx));
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples); // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
for (int k=0; k<n_batch; ++k) { for (int k=0; k<n_batch; ++k) {
// printf("%s: batch %d\n", __func__, k); // printf("%s: batch %d\n", __func__, k);
@ -924,7 +924,7 @@ size_t tokenize_file(
for (llama_token token=0; token < n_vocab; ++token) { for (llama_token token=0; token < n_vocab; ++token) {
max_token_text_size = std::max( max_token_text_size = std::max(
max_token_text_size, max_token_text_size,
strlen(llama_token_get_text(lctx, token))); strlen(llama_token_get_text(llama_get_model(lctx), token)));
} }
// upper bound of context byte length. // upper bound of context byte length.

View File

@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
// is it an end of stream? -> mark the stream as finished // is it an end of stream? -> mark the stream as finished
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) { if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
i_batch[i] = -1; i_batch[i] = -1;
LOG_TEE("\n"); LOG_TEE("\n");
if (n_parallel > 1) { if (n_parallel > 1) {

View File

@ -47,7 +47,7 @@ struct beam_search_callback_data {
// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same. // In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
// For example, eob can be flagged due to maximum token length, stop words, etc. // For example, eob can be flagged due to maximum token length, stop words, etc.
static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) { static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx); return n_tokens && tokens[n_tokens-1] == llama_token_eos(llama_get_model(callback_data.ctx));
} }
// Function matching type llama_beam_search_callback_fn_t. // Function matching type llama_beam_search_callback_fn_t.

View File

@ -246,14 +246,14 @@ int main(int argc, char ** argv) {
if (suff_rm_leading_spc && inp_sfx[0] == space_token) { if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
inp_sfx.erase(inp_sfx.begin()); inp_sfx.erase(inp_sfx.begin());
} }
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx)); inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
if (add_bos) { if (add_bos) {
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx)); inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
} }
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx)); inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx; embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp.push_back(llama_token_middle(ctx)); embd_inp.push_back(llama_token_middle(model));
LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix)); LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix)); LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
@ -261,7 +261,7 @@ int main(int argc, char ** argv) {
// Should not run without any tokens // Should not run without any tokens
if (embd_inp.empty()) { if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx)); embd_inp.push_back(llama_token_bos(model));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
} }
@ -577,10 +577,10 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) { if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode // deal with eot token in infill mode
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){ if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) { if(is_interacting && !params.interactive_first) {
// print an eot token // print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str()); printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
} }
fflush(stdout); fflush(stdout);
printf("\n"); printf("\n");
@ -627,14 +627,14 @@ int main(int argc, char ** argv) {
if (suff_rm_leading_spc && inp_sfx[0] == space_token) { if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
inp_sfx.erase(inp_sfx.begin()); inp_sfx.erase(inp_sfx.begin());
} }
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx)); inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
if (add_bos) { if (add_bos) {
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx)); inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
} }
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx)); inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx; embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp.push_back(llama_token_middle(ctx)); embd_inp.push_back(llama_token_middle(model));
embd.clear(); embd.clear();
embd_guidance.clear(); embd_guidance.clear();
n_remain = params.n_predict; n_remain = params.n_predict;
@ -644,7 +644,7 @@ int main(int argc, char ** argv) {
is_interacting = false; is_interacting = false;
} }
// deal with end of text token in interactive mode // deal with end of text token in interactive mode
else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) { else if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
LOG("found EOS token\n"); LOG("found EOS token\n");
if (params.interactive) { if (params.interactive) {
@ -661,7 +661,7 @@ int main(int argc, char ** argv) {
if (params.input_prefix_bos) { if (params.input_prefix_bos) {
LOG("adding input prefix BOS token\n"); LOG("adding input prefix BOS token\n");
embd_inp.push_back(llama_token_bos(ctx)); embd_inp.push_back(llama_token_bos(model));
} }
std::string buffer; std::string buffer;
@ -724,7 +724,7 @@ int main(int argc, char ** argv) {
} }
// end of text token // end of text token
if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !params.interactive) { if (!embd.empty() && embd.back() == llama_token_eos(model) && !params.interactive) {
break; break;
} }
@ -736,7 +736,7 @@ int main(int argc, char ** argv) {
} }
} }
if (!params.interactive && n_remain <= 0) { if (!params.interactive && n_remain <= 0) {
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str()); printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
fflush(stdout); fflush(stdout);
} }

View File

@ -933,7 +933,7 @@ struct sql_printer : public printer {
}; };
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx)); std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
int n_processed = 0; int n_processed = 0;
llama_set_n_threads(ctx, n_threads, n_threads); llama_set_n_threads(ctx, n_threads, n_threads);
@ -946,7 +946,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
} }
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(ctx); llama_token token = llama_token_bos(llama_get_model(ctx));
llama_set_n_threads(ctx, n_threads, n_threads); llama_set_n_threads(ctx, n_threads, n_threads);

View File

@ -137,7 +137,7 @@ inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
inline const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) { inline const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
int id = sample_id(ctx_llama, params); int id = sample_id(ctx_llama, params);
static std::string ret; static std::string ret;
if (id == llama_token_eos(ctx_llama)) { if (id == llama_token_eos(llama_get_model(ctx_llama))) {
ret = "</s>"; ret = "</s>";
} else { } else {
ret = llama_token_to_piece(ctx_llama, id); ret = llama_token_to_piece(ctx_llama, id);

View File

@ -248,7 +248,7 @@ int main(int argc, char ** argv) {
// Should not run without any tokens // Should not run without any tokens
if (embd_inp.empty()) { if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx)); embd_inp.push_back(llama_token_bos(model));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
} }
@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
} }
// deal with end of text token in interactive mode // deal with end of text token in interactive mode
if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) { if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
LOG("found EOS token\n"); LOG("found EOS token\n");
if (params.interactive) { if (params.interactive) {
@ -720,7 +720,7 @@ int main(int argc, char ** argv) {
if (params.input_prefix_bos) { if (params.input_prefix_bos) {
LOG("adding input prefix BOS token\n"); LOG("adding input prefix BOS token\n");
embd_inp.push_back(llama_token_bos(ctx)); embd_inp.push_back(llama_token_bos(model));
} }
std::string buffer; std::string buffer;
@ -804,7 +804,7 @@ int main(int argc, char ** argv) {
} }
// end of text token // end of text token
if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) { if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) {
LOG_TEE(" [end of text]\n"); LOG_TEE(" [end of text]\n");
break; break;
} }

View File

@ -347,7 +347,7 @@ int main(int argc, char ** argv) {
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
if (client.n_decoded > 2 && if (client.n_decoded > 2 &&
(id == llama_token_eos(ctx) || (id == llama_token_eos(model) ||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
client.response.find("User:") != std::string::npos || client.response.find("User:") != std::string::npos ||
client.response.find('\n') != std::string::npos)) { client.response.find('\n') != std::string::npos)) {

View File

@ -227,7 +227,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (add_bos && j == 0) { if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(ctx); tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
} }
const auto batch_logits = llama_get_logits(ctx); const auto batch_logits = llama_get_logits(ctx);
@ -350,7 +350,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (add_bos && j == 0) { if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(ctx); tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
} }
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {

View File

@ -726,7 +726,7 @@ struct llama_server_context
if (json_value(data, "ignore_eos", false)) if (json_value(data, "ignore_eos", false))
{ {
slot->sparams.logit_bias[llama_token_eos(ctx)] = -INFINITY; slot->sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
} }
const auto &logit_bias = data.find("logit_bias"); const auto &logit_bias = data.find("logit_bias");
@ -1056,7 +1056,7 @@ struct llama_server_context
slot.has_next_token = false; slot.has_next_token = false;
} }
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)) if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
{ {
slot.stopped_eos = true; slot.stopped_eos = true;
slot.has_next_token = false; slot.has_next_token = false;
@ -1130,7 +1130,7 @@ struct llama_server_context
json get_formated_generation(llama_client_slot &slot) json get_formated_generation(llama_client_slot &slot)
{ {
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx)); const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second); eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json { return json {
@ -1555,11 +1555,11 @@ struct llama_server_context
suffix_tokens.erase(suffix_tokens.begin()); suffix_tokens.erase(suffix_tokens.begin());
} }
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx)); prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(ctx)); prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens; prompt_tokens = prefix_tokens;
} }
else else

View File

@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
// is it an end of stream? // is it an end of stream?
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) { if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
LOG_TEE("\n"); LOG_TEE("\n");
break; break;

View File

@ -163,7 +163,7 @@ int main(int argc, char ** argv) {
printf("%s", token_str.c_str()); printf("%s", token_str.c_str());
fflush(stdout); fflush(stdout);
if (id == llama_token_eos(ctx_tgt)) { if (id == llama_token_eos(model_tgt)) {
has_eos = true; has_eos = true;
} }

View File

@ -7493,7 +7493,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
} }
} }
const llama_token eos = llama_token_eos(ctx); const llama_token eos = llama_token_eos(&ctx->model);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar; std::vector<llama_grammar_candidate> candidates_grammar;
@ -7703,7 +7703,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
if (token == llama_token_eos(ctx)) { if (token == llama_token_eos(&ctx->model)) {
for (const auto & stack : grammar->stacks) { for (const auto & stack : grammar->stacks) {
if (stack.empty()) { if (stack.empty()) {
return; return;
@ -8912,7 +8912,7 @@ struct llama_context * llama_new_context_with_model(
// build worst-case graph // build worst-case graph
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
int n_past = cparams.n_ctx - n_tokens; int n_past = cparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
@ -9673,43 +9673,44 @@ float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data(); return ctx->embedding.data();
} }
const char * llama_token_get_text(const struct llama_context * ctx, llama_token token) { const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
return ctx->model.vocab.id_to_token[token].text.c_str(); return model->vocab.id_to_token[token].text.c_str();
} }
float llama_token_get_score(const struct llama_context * ctx, llama_token token) { float llama_token_get_score(const struct llama_model * model, llama_token token) {
return ctx->model.vocab.id_to_token[token].score; return model->vocab.id_to_token[token].score;
} }
llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token) { llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
return ctx->model.vocab.id_to_token[token].type; return model->vocab.id_to_token[token].type;
} }
llama_token llama_token_bos(const struct llama_context * ctx) { llama_token llama_token_bos(const struct llama_model * model) {
return ctx->model.vocab.special_bos_id; return model->vocab.special_bos_id;
} }
llama_token llama_token_eos(const struct llama_context * ctx) { llama_token llama_token_eos(const struct llama_model * model) {
return ctx->model.vocab.special_eos_id; return model->vocab.special_eos_id;
} }
llama_token llama_token_nl(const struct llama_context * ctx) { llama_token llama_token_nl(const struct llama_model * model) {
return ctx->model.vocab.linefeed_id; return model->vocab.linefeed_id;
}
llama_token llama_token_prefix(const struct llama_context * ctx) {
return ctx->model.vocab.special_prefix_id;
} }
llama_token llama_token_middle(const struct llama_context * ctx) { llama_token llama_token_prefix(const struct llama_model * model) {
return ctx->model.vocab.special_middle_id; return model->vocab.special_prefix_id;
} }
llama_token llama_token_suffix(const struct llama_context * ctx) { llama_token llama_token_middle(const struct llama_model * model) {
return ctx->model.vocab.special_suffix_id; return model->vocab.special_middle_id;
} }
llama_token llama_token_eot(const struct llama_context * ctx) { llama_token llama_token_suffix(const struct llama_model * model) {
return ctx->model.vocab.special_eot_id; return model->vocab.special_suffix_id;
}
llama_token llama_token_eot(const struct llama_model * model) {
return model->vocab.special_eot_id;
} }
int llama_tokenize( int llama_tokenize(

21
llama.h
View File

@ -494,21 +494,22 @@ extern "C" {
// Vocab // Vocab
// //
LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token); LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token); LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token); LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
// codellama infill tokens // codellama infill tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of infill middle LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
LLAMA_API llama_token llama_token_eot (const struct llama_context * ctx); // End of infill middle LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
// //
// Tokenization // Tokenization