Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-26 06:10:29 +01:00)

commit f648ca2cee (parent b69a480af4)

    llama : add llama_sampling API + move grammar in libllama

    ggml-ci

Makefile (6)
Makefile
@@ -927,7 +927,6 @@ OBJ_COMMON = \
 common/ngram-cache.o \
 common/sampling.o \
 common/train.o \
-common/grammar-parser.o \
 common/build-info.o \
 common/json-schema-to-grammar.o

@@ -1167,11 +1166,6 @@ common/console.o: \
 common/console.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

-common/grammar-parser.o: \
-	common/grammar-parser.cpp \
-	common/grammar-parser.h
-	$(CXX) $(CXXFLAGS) -c $< -o $@
-
 common/json-schema-to-grammar.o: \
 	common/json-schema-to-grammar.cpp \
 	common/json-schema-to-grammar.h
common/CMakeLists.txt
@@ -58,8 +58,6 @@ add_library(${TARGET} STATIC
     sampling.cpp
     console.h
     console.cpp
-    grammar-parser.h
-    grammar-parser.cpp
     json.hpp
     json-schema-to-grammar.cpp
     train.h
common/common.cpp
@@ -353,16 +353,15 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model)
 }

 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
-    bool invalid_param = false;
-    std::string arg;
-    const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sparams;
-
     for (int i = 1; i < argc; i++) {
-        arg = argv[i];
+        const std::string arg_prefix = "--";
+
+        std::string arg = argv[i];
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }

+        bool invalid_param = false;
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
@@ -386,11 +385,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         get_env("HF_TOKEN", params.hf_token);
     }

+    auto & sparams = params.sparams;
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
         string_process_escapes(params.input_suffix);
-        string_process_escapes(sparams.cfg_negative_prompt);
         for (auto & antiprompt : params.antiprompt) {
             string_process_escapes(antiprompt);
         }
@@ -401,6 +401,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         params.kv_overrides.back().key[0] = 0;
     }

+    if (sparams.seed == LLAMA_DEFAULT_SEED) {
+        sparams.seed = time(NULL);
+    }
+
     return true;
 }

@@ -526,12 +530,10 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';

-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;

     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
-        // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
-        params.seed = std::stoul(argv[i]);
         sparams.seed = std::stoul(argv[i]);
         return true;
     }
@@ -842,12 +844,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     if (arg == "--samplers") {
         CHECK_ARG
         const auto sampler_names = string_split(argv[i], ';');
-        sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
+        sparams.samplers = llama_sampling_types_from_names(sampler_names, true);
         return true;
     }
     if (arg == "--sampling-seq") {
         CHECK_ARG
-        sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]);
+        sparams.samplers = llama_sampling_types_from_chars(argv[i]);
         return true;
     }
     if (arg == "--top-p") {
@@ -873,7 +875,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--typical") {
         CHECK_ARG
-        sparams.typical_p = std::stof(argv[i]);
+        sparams.typ_p = std::stof(argv[i]);
         return true;
     }
     if (arg == "--repeat-last-n") {
@@ -922,30 +924,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.mirostat_tau = std::stof(argv[i]);
         return true;
     }
-    if (arg == "--cfg-negative-prompt") {
-        CHECK_ARG
-        sparams.cfg_negative_prompt = argv[i];
-        return true;
-    }
-    if (arg == "--cfg-negative-prompt-file") {
-        CHECK_ARG
-        std::ifstream file(argv[i]);
-        if (!file) {
-            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
-            invalid_param = true;
-            return true;
-        }
-        std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt));
-        if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
-            sparams.cfg_negative_prompt.pop_back();
-        }
-        return true;
-    }
-    if (arg == "--cfg-scale") {
-        CHECK_ARG
-        sparams.cfg_scale = std::stof(argv[i]);
-        return true;
-    }
     if (arg == "-b" || arg == "--batch-size") {
         CHECK_ARG
         params.n_batch = std::stoi(argv[i]);
@@ -1353,7 +1331,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--ignore-eos") {
-        params.ignore_eos = true;
+        sparams.ignore_eos = true;
         return true;
     }
     if (arg == "--penalize-nl") {
@@ -1368,7 +1346,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         std::string value_str;
         try {
             if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
-                sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                sparams.logit_bias.push_back({key, bias});
             }
             else {
                 throw std::exception();
@@ -1715,13 +1694,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
 #endif

 void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;

     std::string sampler_type_chars;
     std::string sampler_type_names;
-    for (const auto sampler_type : sparams.samplers_sequence) {
-        sampler_type_chars += static_cast<char>(sampler_type);
-        sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";";
+    for (const auto & sampler : sparams.samplers) {
+        sampler_type_chars += llama_sampling_type_to_chr(sampler);
+        sampler_type_names += llama_sampling_type_to_str(sampler) + ";";
     }
     sampler_type_names.pop_back();

@@ -1756,7 +1735,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" });
     options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" });
     options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" });
-    options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed });
     options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.cpuparams.n_threads });
     options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" });
     options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" });
@@ -1836,18 +1814,19 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
         " --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });

     options.push_back({ "sampling" });
+    options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed });
     options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n"
         "(default: %s)", sampler_type_names.c_str() });
     options.push_back({ "*", " --sampling-seq SEQUENCE",
         "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() });
     options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" });
     options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" });
-    options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp });
+    options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp });
     options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k });
-    options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
+    options.push_back({ "*", " --top-p P", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
-    options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
+    options.push_back({ "*", " --min-p P", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
-    options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
+    options.push_back({ "*", " --tfs P", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
-    options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p });
+    options.push_back({ "*", " --typical P", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typ_p });
     options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n });
     options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
     options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
@@ -1862,11 +1841,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
         "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
         "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
-    options.push_back({ "main", " --cfg-negative-prompt PROMPT",
-        "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() });
-    options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
-        "negative prompt file to use for guidance" });
-    options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
     options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
         "set custom jinja chat template (default: template taken from model's metadata)\n"
         "if suffix/prefix are specified, template will be disabled\n"
@@ -2513,8 +2487,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         llama_lora_adapters_apply(lctx, iparams.lora_adapters);
     }

-    if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+    if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
+        fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sparams.ignore_eos = false;
     }

     if (params.warmup) {
@@ -2543,7 +2518,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
-        llama_reset_timings(lctx);
+        llama_reset_timings(lctx, nullptr);
     }

     iparams.model = model;
@@ -2622,7 +2597,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_threads = params.cpuparams.n_threads;
     cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
         params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
-    cparams.seed = params.seed;
     cparams.logits_all = params.logits_all;
     cparams.embeddings = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
@@ -3508,7 +3482,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha

 void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
         const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;

     fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
     fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
@@ -3559,8 +3533,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l

     fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
     fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
-    yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
-    fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
     fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
@@ -3571,10 +3543,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
+    fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
-    fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");

     yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
     fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
@@ -3585,11 +3554,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());

     fprintf(stream, "logit_bias:\n");
-    for (std::pair<llama_token, float> lb : sparams.logit_bias) {
-        if (ignore_eos && lb.first == logit_bias_eos->first) {
-            continue;
-        }
-        fprintf(stream, " %d: %f", lb.first, lb.second);
+    for (const auto & logit_bias : sparams.logit_bias) {
+        fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias);
     }

     fprintf(stream, "lora:\n");
@@ -3642,7 +3608,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l

     fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
     fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
-    fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
@@ -3656,7 +3621,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
     fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
-    fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
+    fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
     fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
 }
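Note: the `--logit-bias` hunk above also changes the representation of `sparams.logit_bias` from a token-to-float map into a list of {token, bias} entries, which is what `llama_sampling_set_logit_bias()` receives later in this commit. A minimal sketch of the new parsing path, assuming a `llama_logit_bias` struct with `token` and `bias` fields (the field names come from the `yaml_dump_non_result_info` hunk; everything else here is illustrative and not part of the commit):

    // Illustrative only: how a "TOKEN(+/-)BIAS" argument such as "15043+1" is parsed
    // into the vector-of-structs representation used after this commit.
    #include <cstdint>
    #include <sstream>
    #include <string>
    #include <vector>

    typedef int32_t llama_token;

    struct llama_logit_bias {   // assumed layout, matching the .token/.bias accesses in the diff
        llama_token token;
        float       bias;
    };

    static bool parse_logit_bias(const std::string & arg, std::vector<llama_logit_bias> & out) {
        std::stringstream ss(arg);
        llama_token key;
        char        sign = 0;
        std::string value_str;
        if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
            const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
            out.push_back({key, bias});   // was: sparams.logit_bias[key] = bias; with a std::map
            return true;
        }
        return false;
    }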
common/common.h
@@ -77,8 +77,6 @@ struct cpu_params {
 };

 struct gpt_params {
-    uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
-
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx     = 0;  // context size
     int32_t n_batch   = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
@@ -120,8 +118,7 @@ struct gpt_params {
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

-    // sampling parameters
-    struct llama_sampling_params sparams;
+    struct gpt_sampling_params sparams;

     std::string model       = ""; // model path
     std::string model_draft = ""; // draft model for speculative decoding
@@ -185,7 +182,6 @@ struct gpt_params {
     bool flash_attn       = false; // flash attention

     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
-    bool ignore_eos       = false; // ignore generated EOS tokens
     bool logits_all       = false; // return logits for all tokens in the batch
     bool use_mmap         = true;  // use mmap for faster loads
     bool use_mlock        = false; // use mlock to keep model in memory
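Note: with the struct change above, the seed and EOS handling are no longer top-level `gpt_params` fields; they live in the embedded `gpt_sampling_params`. A minimal sketch of what call sites look like after the move, using only fields and logic that appear in the common.cpp hunks above (the surrounding function is illustrative, not from the commit):

    // Illustrative fragment, not part of the commit.
    #include "common.h"

    #include <ctime>

    void apply_cli_defaults(gpt_params & params) {
        // the RNG seed now lives in the sampling params (see gpt_params_parse_ex above)
        if (params.sparams.seed == LLAMA_DEFAULT_SEED) {
            params.sparams.seed = time(NULL);
        }

        // --ignore-eos is likewise a sampling option now, instead of params.ignore_eos
        params.sparams.ignore_eos = true;
    }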
common/grammar-parser.cpp (deleted)
@@ -1,539 +0,0 @@
#include "grammar-parser.h"
#include <cstdint>
#include <cwchar>
#include <string>
#include <utility>
#include <stdexcept>
#include <exception>

namespace grammar_parser {
    // NOTE: assumes valid utf8 (but checks for overrun)
    // copied from llama.cpp
    static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
        static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
        uint8_t  first_byte = static_cast<uint8_t>(*src);
        uint8_t  highbits   = first_byte >> 4;
        int      len        = lookup[highbits];
        uint8_t  mask       = (1 << (8 - len)) - 1;
        uint32_t value      = first_byte & mask;
        const char * end    = src + len; // may overrun!
        const char * pos    = src + 1;
        for ( ; pos < end && *pos; pos++) {
            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
        }
        return std::make_pair(value, pos);
    }

    static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
        auto result = state.symbol_ids.emplace(std::string(src, len), next_id);
        return result.first->second;
    }

    static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
        state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
        return next_id;
    }

    static void add_rule(
            parse_state & state,
            uint32_t      rule_id,
            const std::vector<llama_grammar_element> & rule) {
        if (state.rules.size() <= rule_id) {
            state.rules.resize(rule_id + 1);
        }
        state.rules[rule_id] = rule;
    }

    static bool is_digit_char(char c) {
        return '0' <= c && c <= '9';
    }

    static bool is_word_char(char c) {
        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
    }

    static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
        const char * pos   = src;
        const char * end   = src + size;
        uint32_t     value = 0;
        for ( ; pos < end && *pos; pos++) {
            value <<= 4;
            char c = *pos;
            if ('a' <= c && c <= 'f') {
                value += c - 'a' + 10;
            } else if ('A' <= c && c <= 'F') {
                value += c - 'A' + 10;
            } else if ('0' <= c && c <= '9') {
                value += c - '0';
            } else {
                break;
            }
        }
        if (pos != end) {
            throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
        }
        return std::make_pair(value, pos);
    }

    static const char * parse_space(const char * src, bool newline_ok) {
        const char * pos = src;
        while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
                (newline_ok && (*pos == '\r' || *pos == '\n'))) {
            if (*pos == '#') {
                while (*pos && *pos != '\r' && *pos != '\n') {
                    pos++;
                }
            } else {
                pos++;
            }
        }
        return pos;
    }

    static const char * parse_name(const char * src) {
        const char * pos = src;
        while (is_word_char(*pos)) {
            pos++;
        }
        if (pos == src) {
            throw std::runtime_error(std::string("expecting name at ") + src);
        }
        return pos;
    }

    static const char * parse_int(const char * src) {
        const char * pos = src;
        while (is_digit_char(*pos)) {
            pos++;
        }
        if (pos == src) {
            throw std::runtime_error(std::string("expecting integer at ") + src);
        }
        return pos;
    }

    static std::pair<uint32_t, const char *> parse_char(const char * src) {
        if (*src == '\\') {
            switch (src[1]) {
                case 'x': return parse_hex(src + 2, 2);
                case 'u': return parse_hex(src + 2, 4);
                case 'U': return parse_hex(src + 2, 8);
                case 't': return std::make_pair('\t', src + 2);
                case 'r': return std::make_pair('\r', src + 2);
                case 'n': return std::make_pair('\n', src + 2);
                case '\\':
                case '"':
                case '[':
                case ']':
                    return std::make_pair(src[1], src + 2);
                default:
                    throw std::runtime_error(std::string("unknown escape at ") + src);
            }
        } else if (*src) {
            return decode_utf8(src);
        }
        throw std::runtime_error("unexpected end of input");
    }

    const char * parse_alternates(
            parse_state       & state,
            const char        * src,
            const std::string & rule_name,
            uint32_t            rule_id,
            bool                is_nested);

    static const char * parse_sequence(
            parse_state                        & state,
            const char                         * src,
            const std::string                  & rule_name,
            std::vector<llama_grammar_element> & out_elements,
            bool                                 is_nested) {
        size_t last_sym_start = out_elements.size();
        const char * pos = src;

        auto handle_repetitions = [&](int min_times, int max_times) {

            if (last_sym_start == out_elements.size()) {
                throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
            }

            // apply transformation to previous symbol (last_sym_start to end) according to
            // the following rewrite rules:
            // S{m,n} --> S S S (m times) S'(n-m)
            //            S'(x)   ::= S S'(x-1) |
            //            (... n-m definitions of these S' rules ...)
            //            S'(1)   ::= S |
            // S{m,} -->  S S S (m times) S'
            //            S'     ::= S S' |
            // S*     --> S{0,}
            //        -->          S'     ::= S S' |
            // S+     --> S{1,}
            //        --> S S'
            //            S'     ::= S S' |
            // S?     --> S{0,1}
            //        --> S'
            //            S'     ::= S |

            std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
            if (min_times == 0) {
                out_elements.resize(last_sym_start);
            } else {
                // Repeat the previous elements (min_times - 1) times
                for (int i = 1; i < min_times; i++) {
                    out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
                }
            }

            uint32_t last_rec_rule_id = 0;
            auto n_opt = max_times < 0 ? 1 : max_times - min_times;

            std::vector<llama_grammar_element> rec_rule(previous_elements);
            for (int i = 0; i < n_opt; i++) {
                rec_rule.resize(previous_elements.size());
                uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
                if (i > 0 || max_times < 0) {
                    rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
                }
                rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
                rec_rule.push_back({LLAMA_GRETYPE_END, 0});
                add_rule(state, rec_rule_id, rec_rule);
                last_rec_rule_id = rec_rule_id;
            }
            if (n_opt > 0) {
                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
            }
        };

        while (*pos) {
            if (*pos == '"') { // literal string
                pos++;
                last_sym_start = out_elements.size();
                while (*pos != '"') {
                    if (!*pos) {
                        throw std::runtime_error("unexpected end of input");
                    }
                    auto char_pair = parse_char(pos);
                    pos = char_pair.second;
                    out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
                }
                pos = parse_space(pos + 1, is_nested);
            } else if (*pos == '[') { // char range(s)
                pos++;
                enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
                if (*pos == '^') {
                    pos++;
                    start_type = LLAMA_GRETYPE_CHAR_NOT;
                }
                last_sym_start = out_elements.size();
                while (*pos != ']') {
                    if (!*pos) {
                        throw std::runtime_error("unexpected end of input");
                    }
                    auto char_pair = parse_char(pos);
                    pos = char_pair.second;
                    enum llama_gretype type = last_sym_start < out_elements.size()
                        ? LLAMA_GRETYPE_CHAR_ALT
                        : start_type;

                    out_elements.push_back({type, char_pair.first});
                    if (pos[0] == '-' && pos[1] != ']') {
                        if (!pos[1]) {
                            throw std::runtime_error("unexpected end of input");
                        }
                        auto endchar_pair = parse_char(pos + 1);
                        pos = endchar_pair.second;
                        out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
                    }
                }
                pos = parse_space(pos + 1, is_nested);
            } else if (is_word_char(*pos)) { // rule reference
                const char * name_end    = parse_name(pos);
                uint32_t     ref_rule_id = get_symbol_id(state, pos, name_end - pos);
                pos = parse_space(name_end, is_nested);
                last_sym_start = out_elements.size();
                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
            } else if (*pos == '(') { // grouping
                // parse nested alternates into synthesized rule
                pos = parse_space(pos + 1, true);
                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
                pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
                last_sym_start = out_elements.size();
                // output reference to synthesized rule
                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
                if (*pos != ')') {
                    throw std::runtime_error(std::string("expecting ')' at ") + pos);
                }
                pos = parse_space(pos + 1, is_nested);
            } else if (*pos == '.') { // any char
                last_sym_start = out_elements.size();
                out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
                pos = parse_space(pos + 1, is_nested);
            } else if (*pos == '*') {
                pos = parse_space(pos + 1, is_nested);
                handle_repetitions(0, -1);
            } else if (*pos == '+') {
                pos = parse_space(pos + 1, is_nested);
                handle_repetitions(1, -1);
            } else if (*pos == '?') {
                pos = parse_space(pos + 1, is_nested);
                handle_repetitions(0, 1);
            } else if (*pos == '{') {
                pos = parse_space(pos + 1, is_nested);

                if (!is_digit_char(*pos)) {
                    throw std::runtime_error(std::string("expecting an int at ") + pos);
                }
                const char * int_end = parse_int(pos);
                int min_times = std::stoul(std::string(pos, int_end - pos));
                pos = parse_space(int_end, is_nested);

                int max_times = -1;

                if (*pos == '}') {
                    max_times = min_times;
                    pos = parse_space(pos + 1, is_nested);
                } else if (*pos == ',') {
                    pos = parse_space(pos + 1, is_nested);

                    if (is_digit_char(*pos)) {
                        const char * int_end = parse_int(pos);
                        max_times = std::stoul(std::string(pos, int_end - pos));
                        pos = parse_space(int_end, is_nested);
                    }

                    if (*pos != '}') {
                        throw std::runtime_error(std::string("expecting '}' at ") + pos);
                    }
                    pos = parse_space(pos + 1, is_nested);
                } else {
                    throw std::runtime_error(std::string("expecting ',' at ") + pos);
                }
                handle_repetitions(min_times, max_times);
            } else {
                break;
            }
        }
        return pos;
    }

    const char * parse_alternates(
            parse_state       & state,
            const char        * src,
            const std::string & rule_name,
            uint32_t            rule_id,
            bool                is_nested) {
        std::vector<llama_grammar_element> rule;
        const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
        while (*pos == '|') {
            rule.push_back({LLAMA_GRETYPE_ALT, 0});
            pos = parse_space(pos + 1, true);
            pos = parse_sequence(state, pos, rule_name, rule, is_nested);
        }
        rule.push_back({LLAMA_GRETYPE_END, 0});
        add_rule(state, rule_id, rule);
        return pos;
    }

    static const char * parse_rule(parse_state & state, const char * src) {
        const char * name_end = parse_name(src);
        const char * pos      = parse_space(name_end, false);
        size_t       name_len = name_end - src;
        uint32_t     rule_id  = get_symbol_id(state, src, name_len);
        const std::string name(src, name_len);

        if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
            throw std::runtime_error(std::string("expecting ::= at ") + pos);
        }
        pos = parse_space(pos + 3, true);

        pos = parse_alternates(state, pos, name, rule_id, false);

        if (*pos == '\r') {
            pos += pos[1] == '\n' ? 2 : 1;
        } else if (*pos == '\n') {
            pos++;
        } else if (*pos) {
            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
        }
        return parse_space(pos, true);
    }

    parse_state parse(const char * src) {
        try {
            parse_state state;
            const char * pos = parse_space(src, true);
            while (*pos) {
                pos = parse_rule(state, pos);
            }
            // Validate the state to ensure that all rules are defined
            for (const auto & rule : state.rules) {
                if (rule.empty()) {
                    throw std::runtime_error("Undefined rule");
                }
                for (const auto & elem : rule) {
                    if (elem.type == LLAMA_GRETYPE_RULE_REF) {
                        // Ensure that the rule at that location exists
                        if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
                            // Get the name of the rule that is missing
                            for (const auto & kv : state.symbol_ids) {
                                if (kv.second == elem.value) {
                                    throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
                                }
                            }
                        }
                    }
                }
            }
            return state;
        } catch (const std::exception & err) {
            fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
            return parse_state();
        }
    }

    static void print_grammar_char(FILE * file, uint32_t c) {
        if (0x20 <= c && c <= 0x7f) {
            fprintf(file, "%c", static_cast<char>(c));
        } else {
            // cop out of encoding UTF-8
            fprintf(file, "<U+%04X>", c);
        }
    }

    static bool is_char_element(llama_grammar_element elem) {
        switch (elem.type) {
            case LLAMA_GRETYPE_CHAR:           return true;
            case LLAMA_GRETYPE_CHAR_NOT:       return true;
            case LLAMA_GRETYPE_CHAR_ALT:       return true;
            case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
            case LLAMA_GRETYPE_CHAR_ANY:       return true;
            default:                           return false;
        }
    }

    static void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
        for (auto elem : rule) {
            switch (elem.type) {
                case LLAMA_GRETYPE_END:            fprintf(file, "END");            break;
                case LLAMA_GRETYPE_ALT:            fprintf(file, "ALT");            break;
                case LLAMA_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
                case LLAMA_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
                case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
                case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
                case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
                case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
            }
            switch (elem.type) {
                case LLAMA_GRETYPE_END:
                case LLAMA_GRETYPE_ALT:
                case LLAMA_GRETYPE_RULE_REF:
                    fprintf(file, "(%u) ", elem.value);
                    break;
                case LLAMA_GRETYPE_CHAR:
                case LLAMA_GRETYPE_CHAR_NOT:
                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
                case LLAMA_GRETYPE_CHAR_ALT:
                case LLAMA_GRETYPE_CHAR_ANY:
                    fprintf(file, "(\"");
                    print_grammar_char(file, elem.value);
                    fprintf(file, "\") ");
                    break;
            }
        }
        fprintf(file, "\n");
    }

    static void print_rule(
            FILE     * file,
            uint32_t   rule_id,
            const std::vector<llama_grammar_element> & rule,
            const std::map<uint32_t, std::string>    & symbol_id_names) {
        if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
            throw std::runtime_error(
                "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
        }
        fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
        for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
            llama_grammar_element elem = rule[i];
            switch (elem.type) {
                case LLAMA_GRETYPE_END:
                    throw std::runtime_error(
                        "unexpected end of rule: " + std::to_string(rule_id) + "," +
                        std::to_string(i));
                case LLAMA_GRETYPE_ALT:
                    fprintf(file, "| ");
                    break;
                case LLAMA_GRETYPE_RULE_REF:
                    fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
                    break;
                case LLAMA_GRETYPE_CHAR:
                    fprintf(file, "[");
                    print_grammar_char(file, elem.value);
                    break;
                case LLAMA_GRETYPE_CHAR_NOT:
                    fprintf(file, "[^");
                    print_grammar_char(file, elem.value);
                    break;
                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
                    if (i == 0 || !is_char_element(rule[i - 1])) {
                        throw std::runtime_error(
                            "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
                            std::to_string(rule_id) + "," + std::to_string(i));
                    }
                    fprintf(file, "-");
                    print_grammar_char(file, elem.value);
                    break;
                case LLAMA_GRETYPE_CHAR_ALT:
                    if (i == 0 || !is_char_element(rule[i - 1])) {
                        throw std::runtime_error(
                            "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
                            std::to_string(rule_id) + "," + std::to_string(i));
                    }
                    print_grammar_char(file, elem.value);
                    break;
                case LLAMA_GRETYPE_CHAR_ANY:
                    fprintf(file, ".");
                    break;
            }
            if (is_char_element(elem)) {
                switch (rule[i + 1].type) {
                    case LLAMA_GRETYPE_CHAR_ALT:
                    case LLAMA_GRETYPE_CHAR_RNG_UPPER:
                    case LLAMA_GRETYPE_CHAR_ANY:
                        break;
                    default:
                        fprintf(file, "] ");
                }
            }
        }
        fprintf(file, "\n");
    }

    void print_grammar(FILE * file, const parse_state & state) {
        try {
            std::map<uint32_t, std::string> symbol_id_names;
            for (const auto & kv : state.symbol_ids) {
                symbol_id_names[kv.second] = kv.first;
            }
            for (size_t i = 0, end = state.rules.size(); i < end; i++) {
                // fprintf(file, "%zu: ", i);
                // print_rule_binary(file, state.rules[i]);
                print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
                // fprintf(file, "\n");
            }
        } catch (const std::exception & err) {
            fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
        }
    }

    std::vector<const llama_grammar_element *> parse_state::c_rules() {
        std::vector<const llama_grammar_element *> ret;
        ret.reserve(rules.size());
        for (const auto & rule : rules) {
            ret.push_back(rule.data());
        }
        return ret;
    }
}
common/grammar-parser.h (deleted)
@@ -1,29 +0,0 @@
// Implements a parser for an extended Backus-Naur form (BNF), producing the
// binary context-free grammar format specified by llama.h. Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root  ::= expr
// expr  ::= term ([-+*/] term)*
// term  ::= num | "(" space expr ")" space
// num   ::= [0-9]+ space
// space ::= [ \t\n]*

#pragma once
#include "llama.h"
#include <vector>
#include <map>
#include <cstdint>
#include <string>

namespace grammar_parser {
    struct parse_state {
        std::map<std::string, uint32_t> symbol_ids;
        std::vector<std::vector<llama_grammar_element>> rules;

        std::vector<const llama_grammar_element *> c_rules();
    };

    parse_state parse(const char * src);
    void print_grammar(FILE * file, const parse_state & state);
}
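Note: the parser removed above turned a GBNF text grammar into a `parse_state`, where `symbol_ids` maps rule names to indices and `rules[i]` holds the flattened `llama_grammar_element` sequence for rule `i`. A minimal sketch of how callers used it before this commit; the call sequence mirrors the code removed from common/sampling.cpp below, and `"root ::= [0-9]+"` is an arbitrary example grammar:

    // Sketch of the pre-commit usage (fragment, assuming the old LLAMA_API_INTERNAL grammar API).
    #include "grammar-parser.h"
    #include "llama.h"

    #include <vector>

    struct llama_grammar * build_grammar_sketch() {
        grammar_parser::parse_state parsed = grammar_parser::parse("root ::= [0-9]+");

        // parse() returns an empty state on error, and every grammar needs a "root" rule
        if (parsed.rules.empty() || parsed.symbol_ids.find("root") == parsed.symbol_ids.end()) {
            return nullptr;
        }

        std::vector<const llama_grammar_element *> rules = parsed.c_rules();
        return llama_grammar_init(rules.data(), rules.size(), parsed.symbol_ids.at("root"));
    }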
@ -1,141 +1,28 @@
|
|||||||
#define LLAMA_API_INTERNAL
|
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include <random>
|
|
||||||
|
|
||||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
#include "common.h"
|
||||||
struct llama_sampling_context * result = new llama_sampling_context();
|
|
||||||
|
|
||||||
result->params = params;
|
std::string gpt_sampling_params::print_all() const {
|
||||||
result->grammar = nullptr;
|
|
||||||
|
|
||||||
// if there is a grammar, parse it
|
|
||||||
if (!params.grammar.empty()) {
|
|
||||||
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
|
||||||
|
|
||||||
// will be empty (default) if there are parse errors
|
|
||||||
if (result->parsed_grammar.rules.empty()) {
|
|
||||||
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
|
||||||
delete result;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that there is a "root" node.
|
|
||||||
if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
|
|
||||||
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
|
|
||||||
delete result;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
|
|
||||||
|
|
||||||
struct llama_grammar * grammar = llama_grammar_init(
|
|
||||||
grammar_rules.data(),
|
|
||||||
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
|
||||||
if (grammar == nullptr) {
|
|
||||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
|
||||||
}
|
|
||||||
result->grammar = grammar;
|
|
||||||
}
|
|
||||||
|
|
||||||
result->prev.resize(params.n_prev);
|
|
||||||
|
|
||||||
result->n_valid = 0;
|
|
||||||
|
|
||||||
llama_sampling_set_rng_seed(result, params.seed);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_sampling_free(struct llama_sampling_context * ctx) {
|
|
||||||
if (ctx->grammar != NULL) {
|
|
||||||
llama_grammar_free(ctx->grammar);
|
|
||||||
}
|
|
||||||
|
|
||||||
delete ctx;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_sampling_reset(llama_sampling_context * ctx) {
|
|
||||||
if (ctx->grammar != NULL) {
|
|
||||||
llama_grammar_free(ctx->grammar);
|
|
||||||
ctx->grammar = NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!ctx->parsed_grammar.rules.empty()) {
|
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
|
|
||||||
|
|
||||||
struct llama_grammar * grammar = llama_grammar_init(
|
|
||||||
grammar_rules.data(),
|
|
||||||
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
|
||||||
if (grammar == nullptr) {
|
|
||||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
|
||||||
}
|
|
||||||
ctx->grammar = grammar;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
|
||||||
ctx->cur.clear();
|
|
||||||
ctx->n_valid = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
|
|
||||||
if (seed == LLAMA_DEFAULT_SEED) {
|
|
||||||
seed = std::random_device{}();
|
|
||||||
}
|
|
||||||
ctx->rng.seed(seed);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
|
|
||||||
if (dst->grammar) {
|
|
||||||
llama_grammar_free(dst->grammar);
|
|
||||||
dst->grammar = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (src->grammar) {
|
|
||||||
dst->grammar = llama_grammar_copy(src->grammar);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst->prev = src->prev;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token llama_sampling_last(llama_sampling_context * ctx) {
|
|
||||||
return ctx->prev.back();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
|
|
||||||
const int size = ctx_sampling->prev.size();
|
|
||||||
|
|
||||||
n = std::min(n, size);
|
|
||||||
|
|
||||||
std::string result;
|
|
||||||
|
|
||||||
for (int i = size - n; i < size; i++) {
|
|
||||||
result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
-std::string llama_sampling_print(const llama_sampling_params & params) {
+std::string gpt_sampling_params::print_all() const {
     char result[1024];

     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
             "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
-            params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
-            params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
-            params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
+            top_k, tfs_z, top_p, min_p, typ_p, temp,
+            mirostat, mirostat_eta, mirostat_tau);

     return std::string(result);
 }

-std::string llama_sampling_order_print(const llama_sampling_params & params) {
+std::string gpt_sampling_params::print_samplers() const {
     std::string result = "CFG -> Penalties ";
-    if (params.mirostat == 0) {
-        for (auto sampler_type : params.samplers_sequence) {
-            const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
-            if (!sampler_type_name.empty()) {
-                result += "-> " + sampler_type_name + " ";
+    if (mirostat == 0) {
+        for (const auto & sampler : samplers) {
+            const auto name = llama_sampling_type_to_str(sampler);
+            if (!name.empty()) {
+                result += "-> " + name + " ";
             }
         }
     } else {
@@ -145,316 +32,191 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) {
     return result;
 }

-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
-    switch (sampler_type) {
-        case llama_sampler_type::TOP_K: return "top_k";
-        case llama_sampler_type::TFS_Z: return "tfs_z";
-        case llama_sampler_type::TYPICAL_P: return "typical_p";
-        case llama_sampler_type::TOP_P: return "top_p";
-        case llama_sampler_type::MIN_P: return "min_p";
-        case llama_sampler_type::TEMPERATURE: return "temperature";
+struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) {
+    llama_sampling_params lparams = llama_sampling_default_params();
+
+    lparams.seed              = params.seed;
+    lparams.n_prev            = params.n_prev;
+    lparams.n_probs           = params.n_probs;
+    lparams.min_keep          = params.min_keep;
+    lparams.top_k             = params.top_k;
+    lparams.top_p             = params.top_p;
+    lparams.min_p             = params.min_p;
+    lparams.tfs_z             = params.tfs_z;
+    lparams.typ_p             = params.typ_p;
+    lparams.temp              = params.temp;
+    lparams.dynatemp_range    = params.dynatemp_range;
+    lparams.dynatemp_exponent = params.dynatemp_exponent;
+    lparams.penalty_last_n    = params.penalty_last_n;
+    lparams.penalty_repeat    = params.penalty_repeat;
+    lparams.penalty_freq      = params.penalty_freq;
+    lparams.penalty_present   = params.penalty_present;
+    lparams.mirostat          = params.mirostat;
+    lparams.mirostat_tau      = params.mirostat_tau;
+    lparams.mirostat_eta      = params.mirostat_eta;
+    lparams.penalize_nl       = params.penalize_nl;
+    lparams.ignore_eos        = params.ignore_eos;
+
+    lparams.n_samplers = params.samplers.size();
+    for (int i = 0; i < lparams.n_samplers; i++) {
+        lparams.samplers[i] = params.samplers[i];
+    }
+
+    struct llama_sampling * result = llama_sampling_init(model, lparams);
+
+    llama_sampling_set_grammar   (result, params.grammar.c_str(), "root");
+    llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data());
+
+    return result;
+}
+
+void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst) {
+    if (dst) {
+        llama_sampling_free(dst);
+    }
+
+    dst = llama_sampling_cp(src);
+}
+
+llama_token llama_sampling_sample(
+        struct llama_sampling * smpl,
+        struct llama_context * ctx,
+        int idx) {
+    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
+
+    // first, sample the token without any grammar constraints
+    const llama_token id = llama_sampling_sample(smpl, nullptr);
+
+    // create an array with a single token data element for the sampled id
+    llama_token_data single_token_data = { id, 1.0f, 0.0f };
+    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+
+    llama_sampling_grammar(smpl, &single_token_data_array);
+
+    // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+    const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+    if (is_valid) {
+        return id;
+    }
+
+    // if the token is not valid, sample again, after applying the grammar constraints
+    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
+
+    llama_sampling_grammar(smpl, nullptr);
+
+    return llama_sampling_sample(smpl, nullptr);
+}
+
+std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) {
+    n = std::min(n, llama_sampling_n_prev(smpl));
+
+    if (n <= 0) {
+        return "";
+    }
+
+    std::string result;
+    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
+
+    for (int i = n - 1; i >= 0; i--) {
+        const llama_token id = llama_sampling_prev(smpl, i);
+
+        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
+
+        result += llama_token_to_piece(ctx_main, id);
+    }
+
+    return result;
+}
+
+char llama_sampling_type_to_chr(llama_sampler_type sampler) {
+    switch (sampler) {
+        case LLAMA_SAMPLER_TYPE_TOP_K:       return 'k';
+        case LLAMA_SAMPLER_TYPE_TFS_Z:       return 'f';
+        case LLAMA_SAMPLER_TYPE_TYPICAL_P:   return 'y';
+        case LLAMA_SAMPLER_TYPE_TOP_P:       return 'p';
+        case LLAMA_SAMPLER_TYPE_MIN_P:       return 'm';
+        case LLAMA_SAMPLER_TYPE_TEMPERATURE: return 't';
+        default : return '?';
+    }
+}
+
+std::string llama_sampling_type_to_str(llama_sampler_type sampler) {
+    switch (sampler) {
+        case LLAMA_SAMPLER_TYPE_TOP_K:       return "top_k";
+        case LLAMA_SAMPLER_TYPE_TFS_Z:       return "tfs_z";
+        case LLAMA_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
+        case LLAMA_SAMPLER_TYPE_TOP_P:       return "top_p";
+        case LLAMA_SAMPLER_TYPE_MIN_P:       return "min_p";
+        case LLAMA_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         default : return "";
     }
 }
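For orientation (added here, not part of the patch): a minimal sketch of how common code is expected to build a sampler from the new gpt_sampling_params via the overload defined above. The field values, the grammar string and the {token, bias} layout of llama_logit_bias are illustrative assumptions; only llama_sampling_init, llama_sampling_free and the fields shown in this diff are taken from the patch.

    // sketch only - assumes the "common/sampling.h" from this patch is included
    // and `model` is an already-loaded llama_model *
    gpt_sampling_params sparams;

    sparams.top_k   = 40;                          // illustrative values
    sparams.top_p   = 0.9f;
    sparams.temp    = 0.7f;
    sparams.grammar = "root ::= \"yes\" | \"no\""; // made-up GBNF grammar
    sparams.logit_bias.push_back({ 2, -1.0f });    // assumed llama_logit_bias {token, bias} layout

    // the overload copies the fields into llama_sampling_params, then applies
    // the grammar and the logit biases to the returned sampler
    struct llama_sampling * smpl = llama_sampling_init(model, sparams);

    // ... generate with llama_sampling_sample(smpl, ctx, idx) ...

    llama_sampling_free(smpl);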

 std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
     std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
-        {"top_k",       llama_sampler_type::TOP_K},
-        {"top_p",       llama_sampler_type::TOP_P},
-        {"typical_p",   llama_sampler_type::TYPICAL_P},
-        {"min_p",       llama_sampler_type::MIN_P},
-        {"tfs_z",       llama_sampler_type::TFS_Z},
-        {"temperature", llama_sampler_type::TEMPERATURE}
+        { "top_k",       LLAMA_SAMPLER_TYPE_TOP_K },
+        { "top_p",       LLAMA_SAMPLER_TYPE_TOP_P },
+        { "typ_p",       LLAMA_SAMPLER_TYPE_TYPICAL_P },
+        { "min_p",       LLAMA_SAMPLER_TYPE_MIN_P },
+        { "tfs_z",       LLAMA_SAMPLER_TYPE_TFS_Z },
+        { "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE },
     };

     // since samplers names are written multiple ways
     // make it ready for both system names and input names
     std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
-        {"top-k",     llama_sampler_type::TOP_K},
-        {"top-p",     llama_sampler_type::TOP_P},
-        {"nucleus",   llama_sampler_type::TOP_P},
-        {"typical-p", llama_sampler_type::TYPICAL_P},
-        {"typical",   llama_sampler_type::TYPICAL_P},
-        {"min-p",     llama_sampler_type::MIN_P},
-        {"tfs-z",     llama_sampler_type::TFS_Z},
-        {"tfs",       llama_sampler_type::TFS_Z},
-        {"temp",      llama_sampler_type::TEMPERATURE}
+        { "top-k",     LLAMA_SAMPLER_TYPE_TOP_K },
+        { "top-p",     LLAMA_SAMPLER_TYPE_TOP_P },
+        { "nucleus",   LLAMA_SAMPLER_TYPE_TOP_P },
+        { "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
+        { "typical",   LLAMA_SAMPLER_TYPE_TYPICAL_P },
+        { "typ-p",     LLAMA_SAMPLER_TYPE_TYPICAL_P },
+        { "typ",       LLAMA_SAMPLER_TYPE_TYPICAL_P },
+        { "min-p",     LLAMA_SAMPLER_TYPE_MIN_P },
+        { "tfs-z",     LLAMA_SAMPLER_TYPE_TFS_Z },
+        { "tfs",       LLAMA_SAMPLER_TYPE_TFS_Z },
+        { "temp",      LLAMA_SAMPLER_TYPE_TEMPERATURE },
     };

-    std::vector<llama_sampler_type> sampler_types;
-    sampler_types.reserve(names.size());
-    for (const auto & name : names)
-    {
-        auto sampler_item = sampler_canonical_name_map.find(name);
-        if (sampler_item != sampler_canonical_name_map.end())
-        {
-            sampler_types.push_back(sampler_item->second);
-        }
-        else
-        {
-            if (allow_alt_names)
-            {
-                sampler_item = sampler_alt_name_map.find(name);
-                if (sampler_item != sampler_alt_name_map.end())
-                {
-                    sampler_types.push_back(sampler_item->second);
-                }
-            }
-        }
-    }
-    return sampler_types;
+    std::vector<llama_sampler_type> samplers;
+    samplers.reserve(names.size());
+
+    for (const auto & name : names) {
+        auto sampler = sampler_canonical_name_map.find(name);
+        if (sampler != sampler_canonical_name_map.end()) {
+            samplers.push_back(sampler->second);
+        } else {
+            if (allow_alt_names) {
+                sampler = sampler_alt_name_map.find(name);
+                if (sampler != sampler_alt_name_map.end()) {
+                    samplers.push_back(sampler->second);
+                }
+            }
+        }
+    }
+
+    return samplers;
 }

-std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
+std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & chars) {
     std::unordered_map<char, llama_sampler_type> sampler_name_map {
-        {'k', llama_sampler_type::TOP_K},
-        {'p', llama_sampler_type::TOP_P},
-        {'y', llama_sampler_type::TYPICAL_P},
-        {'m', llama_sampler_type::MIN_P},
-        {'f', llama_sampler_type::TFS_Z},
-        {'t', llama_sampler_type::TEMPERATURE}
+        { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_K),       LLAMA_SAMPLER_TYPE_TOP_K },
+        { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TFS_Z),       LLAMA_SAMPLER_TYPE_TFS_Z },
+        { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TYPICAL_P),   LLAMA_SAMPLER_TYPE_TYPICAL_P },
+        { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_P),       LLAMA_SAMPLER_TYPE_TOP_P },
+        { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_MIN_P),       LLAMA_SAMPLER_TYPE_MIN_P },
+        { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE }
     };

-    std::vector<llama_sampler_type> sampler_types;
-    sampler_types.reserve(names_string.size());
-    for (const auto & c : names_string) {
-        const auto sampler_item = sampler_name_map.find(c);
-        if (sampler_item != sampler_name_map.end()) {
-            sampler_types.push_back(sampler_item->second);
-        }
-    }
-    return sampler_types;
+    std::vector<llama_sampler_type> samplers;
+    samplers.reserve(chars.size());
+
+    for (const auto & c : chars) {
+        const auto sampler = sampler_name_map.find(c);
+        if (sampler != sampler_name_map.end()) {
+            samplers.push_back(sampler->second);
+        }
+    }
+
+    return samplers;
 }
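As a usage note (added here, not in the patch), the two parsers map CLI sampler specifications onto llama_sampler_type values. The expected results in the comments below are inferred from the tables above, with "typ_p" as the new canonical spelling of typical_p and the single-character codes matching llama_sampling_type_to_chr().

    // sketch only - assumes the "common/sampling.h" from this patch
    std::vector<std::string> names = { "top_k", "typ-p", "temperature" };

    // "typ-p" is only recognized through the alternative-name table
    auto by_name = llama_sampling_types_from_names(names, /*allow_alt_names=*/true);
    // expected: { LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TYPICAL_P, LLAMA_SAMPLER_TYPE_TEMPERATURE }

    auto by_char = llama_sampling_types_from_chars("kfypmt");
    // expected: the default sampler order - TOP_K, TFS_Z, TYPICAL_P, TOP_P, MIN_P, TEMPERATURE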

-// no reasons to expose this function in header
-static void sampler_queue(
-        struct llama_context * ctx_main,
-        const llama_sampling_params & params,
-        llama_token_data_array & cur_p,
-        size_t min_keep) {
-    const float temp              = params.temp;
-    const float dynatemp_range    = params.dynatemp_range;
-    const float dynatemp_exponent = params.dynatemp_exponent;
-    const int32_t top_k           = params.top_k;
-    const float top_p             = params.top_p;
-    const float min_p             = params.min_p;
-    const float tfs_z             = params.tfs_z;
-    const float typical_p         = params.typical_p;
-    const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
-
-    for (auto sampler_type : samplers_sequence) {
-        switch (sampler_type) {
-            case llama_sampler_type::TOP_K    : llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break;
-            case llama_sampler_type::TFS_Z    : llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break;
-            case llama_sampler_type::TYPICAL_P: llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
-            case llama_sampler_type::TOP_P    : llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
-            case llama_sampler_type::MIN_P    : llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
-            case llama_sampler_type::TEMPERATURE:
-                if (dynatemp_range > 0) {
-                    float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
-                    float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
-                    llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
-                } else {
-                    llama_sample_temp(ctx_main, &cur_p, temp);
-                }
-                break;
-            default : break;
-        }
-    }
-}
-
-static llama_token llama_sampling_sample_impl(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool is_resampling) {
-    const llama_sampling_params & params = ctx_sampling->params;
-
-    const float temp         = params.temp;
-    const int   mirostat     = params.mirostat;
-    const float mirostat_tau = params.mirostat_tau;
-    const float mirostat_eta = params.mirostat_eta;
-
-    std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
-    if (ctx_sampling->grammar != NULL && !is_resampling) {
-        GGML_ASSERT(!original_logits.empty());
-    }
-    llama_token id = 0;
-
-    if (temp < 0.0) {
-        // greedy sampling, with probs
-        llama_sample_softmax(ctx_main, &cur_p);
-        id = cur_p.data[0].id;
-    } else if (temp == 0.0) {
-        // greedy sampling, no probs
-        id = llama_sample_token_greedy(ctx_main, &cur_p);
-    } else {
-        if (mirostat == 1) {
-            const int mirostat_m = 100;
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
-        } else if (mirostat == 2) {
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
-        } else {
-            // temperature sampling
-            size_t min_keep = std::max(1, params.min_keep);
-
-            sampler_queue(ctx_main, params, cur_p, min_keep);
-
-            id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
-
-            //{
-            //    const int n_top = 10;
-            //    LOG("top %d candidates:\n", n_top);
-
-            //    for (int i = 0; i < n_top; i++) {
-            //        const llama_token id = cur_p.data[i].id;
-            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
-            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
-            //    }
-            //}
-
-            //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
-        }
-    }
-
-    if (ctx_sampling->grammar != NULL && !is_resampling) {
-        // Get a pointer to the logits
-        float * logits = llama_get_logits_ith(ctx_main, idx);
-
-        // Create an array with a single token data element for the sampled id
-        llama_token_data single_token_data = {id, logits[id], 0.0f};
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
-
-        // Apply grammar constraints to the single token
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
-
-        // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
-        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
-
-        // If the token is not valid according to the grammar, perform resampling
-        if (!is_valid) {
-            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
-
-            // Restore logits from the copy
-            std::copy(original_logits.begin(), original_logits.end(), logits);
-
-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
-        }
-    }
-
-    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
-
-    return id;
-}
-
-static llama_token_data_array llama_sampling_prepare_impl(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool apply_grammar,
-        std::vector<float> * original_logits) {
-    const llama_sampling_params & params = ctx_sampling->params;
-
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
-    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float   penalty_repeat  = params.penalty_repeat;
-    const float   penalty_freq    = params.penalty_freq;
-    const float   penalty_present = params.penalty_present;
-
-    const bool penalize_nl = params.penalize_nl;
-
-    auto & prev = ctx_sampling->prev;
-    auto & cur  = ctx_sampling->cur;
-
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
-
-    if (ctx_sampling->grammar != NULL && !apply_grammar) {
-        GGML_ASSERT(original_logits != NULL);
-        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        *original_logits = {logits, logits + n_vocab};
-    }
-
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
-    }
-
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
-    }
-
-    cur.resize(n_vocab);
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
-    // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    // apply grammar checks before sampling logic
-    if (apply_grammar && ctx_sampling->grammar != NULL) {
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
-    }
-
-    return cur_p;
-}
-
-llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx) {
-    // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
-}
-
-llama_token_data_array llama_sampling_prepare(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool apply_grammar,
-        std::vector<float> * original_logits) {
-    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
-}
-
-void llama_sampling_accept(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        llama_token id,
-        bool apply_grammar) {
-    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-    ctx_sampling->prev.push_back(id);
-
-    if (ctx_sampling->grammar != NULL && apply_grammar) {
-        llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
-    }
 }
@@ -2,25 +2,13 @@

 #include "llama.h"

-#include "grammar-parser.h"
-
-#include <random>
 #include <string>
-#include <unordered_map>
 #include <vector>

-// sampler types
-enum class llama_sampler_type : char {
-    TOP_K       = 'k',
-    TOP_P       = 'p',
-    MIN_P       = 'm',
-    TFS_Z       = 'f',
-    TYPICAL_P   = 'y',
-    TEMPERATURE = 't'
-};
-
 // sampling parameters
-typedef struct llama_sampling_params {
+typedef struct gpt_sampling_params {
+    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling
+
     int32_t n_prev   = 64; // number of previous tokens to remember
     int32_t n_probs  = 0;  // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep = 0;  // 0 = disabled, otherwise samplers should return at least min_keep tokens
@@ -28,7 +16,7 @@ typedef struct llama_sampling_params {
     float top_p = 0.95f; // 1.0 = disabled
     float min_p = 0.05f; // 0.0 = disabled
     float tfs_z = 1.00f; // 1.0 = disabled
-    float typical_p = 1.00f; // 1.0 = disabled
+    float typ_p     = 1.00f; // typical_p, 1.0 = disabled
     float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float dynatemp_range    = 0.00f; // 0.0 = disabled
     float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
@@ -40,121 +28,52 @@ typedef struct llama_sampling_params {
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
     bool  penalize_nl  = false; // consider newlines as a repeatable token
-    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    bool  ignore_eos   = false;

-    std::vector<llama_sampler_type> samplers_sequence = {
-        llama_sampler_type::TOP_K,
-        llama_sampler_type::TFS_Z,
-        llama_sampler_type::TYPICAL_P,
-        llama_sampler_type::TOP_P,
-        llama_sampler_type::MIN_P,
-        llama_sampler_type::TEMPERATURE
+    std::vector<enum llama_sampler_type> samplers = {
+        LLAMA_SAMPLER_TYPE_TOP_K,
+        LLAMA_SAMPLER_TYPE_TFS_Z,
+        LLAMA_SAMPLER_TYPE_TYPICAL_P,
+        LLAMA_SAMPLER_TYPE_TOP_P,
+        LLAMA_SAMPLER_TYPE_MIN_P,
+        LLAMA_SAMPLER_TYPE_TEMPERATURE
     };

     std::string grammar; // optional BNF-like grammar to constrain sampling

-    // Classifier-Free Guidance
-    // https://arxiv.org/abs/2306.17806
-    std::string cfg_negative_prompt; // string to help guidance
-    float       cfg_scale = 1.f;     // how strong is guidance
-
-    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
-    std::vector<llama_token> penalty_prompt_tokens;
-    bool use_penalty_prompt_tokens = false;
-} llama_sampling_params;
+    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+
+    // print the parameters into a string
+    std::string print_all() const;
+
+    // print the samplers into a string
+    std::string print_samplers() const;
+} gpt_sampling_params;

-// general sampler context
-// TODO: move to llama.h
-struct llama_sampling_context {
-    // parameters that will be used for sampling
-    llama_sampling_params params;
-
-    // mirostat sampler state
-    float mirostat_mu;
-
-    llama_grammar * grammar;
-
-    // internal
-    grammar_parser::parse_state parsed_grammar;
-
-    // TODO: replace with ring-buffer
-    std::vector<llama_token>      prev;
-    std::vector<llama_token_data> cur;
-    size_t n_valid; // Number of correct top tokens with correct probabilities.
-
-    std::mt19937 rng;
-};
-
-#include "common.h"
-
-// Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
-
-void llama_sampling_free(struct llama_sampling_context * ctx);
-
-// Reset the sampler context
-// - clear prev tokens
-// - reset grammar
-void llama_sampling_reset(llama_sampling_context * ctx);
-
-// Set the sampler seed
-void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
-
-// Copy the sampler context
-void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
-
-// Get the last sampled token
-llama_token llama_sampling_last(llama_sampling_context * ctx);
-
-// Get a string representation of the last sampled tokens
-std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
-
-// Print sampling parameters into a string
-std::string llama_sampling_print(const llama_sampling_params & params);
-
-// Print sampling order into a string
-std::string llama_sampling_order_print(const llama_sampling_params & params);
-
-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
-
-std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
-std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
-
-// this is a common sampling function used across the examples for convenience
-// it can serve as a starting point for implementing your own sampling function
-// Note: When using multiple sequences, it is the caller's responsibility to call
-//       llama_sampling_reset when a sequence ends
+// overload of llama_sampling_init using gpt_sampling_params
+struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params);
+
+void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst);
+
+// common sampling implementation:
 //
-// required:
-//  - ctx_main:     context to use for sampling
-//  - ctx_sampling: sampling-specific context
+// - set logits
+// - apply the configured sampling constraints
+// - check if the token fits the grammar (if any)
+// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
 //
-// optional:
-//  - ctx_cfg:      context to use for classifier-free guidance
-//  - idx:          sample from llama_get_logits_ith(ctx, idx)
-//
-// returns:
-//  - token:      sampled token
-//  - candidates: vector of candidate tokens
-//
 llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        int idx = -1);
+        struct llama_sampling * smpl,
+        struct llama_context * ctx,
+        int idx);

-// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
-llama_token_data_array llama_sampling_prepare(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        int idx = 0,
-        bool apply_grammar = true,
-        std::vector<float> * original_logits = nullptr);
+// helpers

-void llama_sampling_accept(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        llama_token id,
-        bool apply_grammar);
+// get a string representation of the last accepted tokens
+std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n);
+
+char        llama_sampling_type_to_chr(enum llama_sampler_type sampler_type);
+std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type);
+
+std::vector<enum llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
+std::vector<enum llama_sampler_type> llama_sampling_types_from_chars(const std::string & chars);
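To make the header comment above concrete, here is a hedged sketch (added here, not part of the patch) of the low-level flow an example can follow with a llama_sampling object. The calls are the ones used by the updated examples later in this diff; passing nullptr is assumed to mean "operate on the sampler's internal candidate list".

    // sketch only - smpl was created with llama_sampling_init(), ctx has been decoded,
    // and idx selects which token's logits to read (e.g. batch.n_tokens - 1)
    const float * logits = llama_get_logits_ith(ctx, idx);

    llama_sampling_set_logits(smpl, logits);

    // constraints use the parameters the sampler was initialized with
    llama_sampling_top_k(smpl, nullptr);
    llama_sampling_top_p(smpl, nullptr);
    llama_sampling_temp (smpl, nullptr);

    const llama_token id = llama_sampling_sample_dist(smpl, nullptr);
    // or, deterministically: llama_sampling_sample_greedy(smpl, nullptr);

    // the common helper llama_sampling_sample(smpl, ctx, idx) wraps this flow and adds
    // the grammar check / resample fallback described above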
@@ -200,7 +200,7 @@ int main(int argc, char ** argv) {
         }
     }

-    llama_print_timings(ctx);
+    llama_print_timings(ctx, nullptr);

     llama_batch_free(batch);

@@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo
     print("Failed to load model")
     exit(1)
 }

 defer {
     llama_free_model(model)
 }

@@ -37,7 +36,6 @@ var tokens = tokenize(text: prompt, add_bos: true)
 let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)

 var context_params = llama_context_default_params()
-context_params.seed = 1234
 context_params.n_ctx = n_kv_req
 context_params.n_batch = UInt32(max(n_len, n_parallel))
 context_params.n_threads = 8
@@ -48,11 +46,24 @@ guard context != nil else {
     print("Failed to initialize context")
     exit(1)
 }

 defer {
     llama_free(context)
 }

+var sparams = llama_sampling_params()
+sparams.top_k = 40
+sparams.top_p = 0.9
+sparams.temp = 0.4
+
+let smpl = llama_sampling_init(model, sparams)
+guard smpl != nil else {
+    print("Failed to initialize sampling")
+    exit(1)
+}
+defer {
+    llama_sampling_free(smpl)
+}
+
 let n_ctx = llama_n_ctx(context)

 print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
@@ -125,32 +136,17 @@ while n_cur <= n_len {
         continue
     }

-    var n_vocab = llama_n_vocab(model)
     var logits = llama_get_logits_ith(context, i_batch[i])

-    var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
-
-    for token_id in 0 ..< n_vocab {
-        candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
-    }
-
-    var candidates_p: llama_token_data_array = .init(
-        data: &candidates,
-        size: candidates.count,
-        sorted: false
-    )
-
-    let top_k: Int32 = 40
-    let top_p: Float = 0.9
-    let temp: Float = 0.4
-
-    llama_sample_top_k(context, &candidates_p, top_k, 1)
-    llama_sample_top_p(context, &candidates_p, top_p, 1)
-    llama_sample_temp(context, &candidates_p, temp)
-
-    let new_token_id = llama_sample_token(context, &candidates_p)
-
-    // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+    llama_sampling_set_logits(smpl, logits)
+
+    llama_sampling_top_k(smpl, nil)
+    llama_sampling_top_p(smpl, nil)
+    llama_sampling_temp (smpl, nil)
+
+    let new_token_id = llama_sampling_sample_dist(smpl, nil)
+
+    // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil);

     // is it an end of stream? -> mark the stream as finished
     if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
@@ -212,7 +208,7 @@ let t_main_end = ggml_time_us()

 print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")

-llama_print_timings(context)
+llama_print_timings(context, smpl)

 private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let utf8Count = text.utf8.count
@@ -2,7 +2,6 @@
 #include "llama.h"

 #include <algorithm>
-#include <cmath>
 #include <cstdio>
 #include <string>
 #include <vector>
@@ -65,6 +64,15 @@ int main(int argc, char ** argv) {

     llama_context * ctx = llama_new_context_with_model(model, ctx_params);

+    auto sparams = llama_sampling_default_params();
+
+    sparams.seed  = params.sparams.seed;
+    sparams.top_k = 40;
+    sparams.top_p = 0.9f;
+    sparams.temp  = 0.4f;
+
+    llama_sampling * smpl = llama_sampling_init(model, sparams);
+
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
         return 1;
@@ -164,29 +172,17 @@ int main(int argc, char ** argv) {
             continue;
         }

-        auto n_vocab = llama_n_vocab(model);
-        auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
+        const auto * logits = llama_get_logits_ith(ctx, i_batch[i]);

-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
+        llama_sampling_set_logits(smpl, logits);

-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-        }
+        llama_sampling_top_k(smpl, nullptr);
+        llama_sampling_top_p(smpl, nullptr);
+        llama_sampling_temp (smpl, nullptr);

-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        const llama_token new_token_id = llama_sampling_sample_dist(smpl, nullptr);

-        const int   top_k = 40;
-        const float top_p = 0.9f;
-        const float temp  = 0.4f;
-
-        llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-        llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-        llama_sample_temp (ctx, &candidates_p, temp);
-
-        const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
-
-        //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+        //const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);

         // is it an end of generation? -> mark the stream as finished
         if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@@ -244,12 +240,13 @@ int main(int argc, char ** argv) {
     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

-    llama_print_timings(ctx);
+    llama_print_timings(ctx, smpl);

     fprintf(stderr, "\n");

     llama_batch_free(batch);

+    llama_sampling_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
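The per-token block above repeats in a few of the updated examples. A hypothetical helper (not in the patch) that wraps it could look like the following; it uses only calls introduced by this change.

    // hypothetical convenience wrapper - not part of the patch
    static llama_token sample_batched_token(llama_sampling * smpl, llama_context * ctx, int32_t i_batch) {
        const float * logits = llama_get_logits_ith(ctx, i_batch);

        llama_sampling_set_logits(smpl, logits);

        // top_k / top_p / temp values come from the sparams used at init time
        llama_sampling_top_k(smpl, nullptr);
        llama_sampling_top_p(smpl, nullptr);
        llama_sampling_temp (smpl, nullptr);

        return llama_sampling_sample_dist(smpl, nullptr);
    }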
@@ -90,13 +90,7 @@ int main(int argc, char ** argv) {

     print_build_info();

-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

     llama_backend_init();
     llama_numa_init(params.numa);
@@ -314,7 +308,7 @@ int main(int argc, char ** argv) {
     }

     // clean up
-    llama_print_timings(ctx);
+    llama_print_timings(ctx, nullptr);
     llama_batch_free(batch);
     llama_free(ctx);
     llama_free_model(model);
@@ -151,8 +151,6 @@ int main(int argc, char ** argv) {

     print_build_info();

-    std::mt19937 rng(params.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);

@@ -183,7 +181,7 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    llama_print_timings(ctx);
+    llama_print_timings(ctx, nullptr);

     llama_free(ctx);
     llama_free_model(model);
@@ -1,9 +1,5 @@
-#define LLAMA_API_INTERNAL
-
-#include "grammar-parser.h"
-#include "ggml.h"
-#include "llama.h"
 #include "unicode.h"
+#include "llama-grammar.h"

 #include <cstdio>
 #include <cstdlib>
@@ -12,22 +8,21 @@
 #include <string>
 #include <vector>

-static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
-    auto decoded = decode_utf8(input_str, {});
-    const auto & code_points = decoded.first;
+static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
+    const auto cpts = unicode_cpts_from_utf8(input_str);

     const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
           llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);

     size_t pos = 0;
-    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+    for (const auto & cpt : cpts) {
         const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy

-        llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+        cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);

         if (cur_stacks.empty()) {
             error_pos = pos;
-            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
+            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
             cur_stacks = prev_stacks;
             return false;
         }
@@ -85,27 +80,7 @@ int main(int argc, char** argv) {
         grammar_str = buffer.str();
     }

-    // Parse the GBNF grammar
-    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
-    // will be empty (default) if there are parse errors
-    if (parsed_grammar.rules.empty()) {
-        fprintf(stdout, "%s: failed to parse grammar\n", __func__);
-        return 1;
-    }
-
-    // Ensure that there is a "root" node.
-    if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) {
-        fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__);
-        return 1;
-    }
-
-    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-
-    // Create the LLAMA grammar
-    auto grammar = llama_grammar_init(
-            grammar_rules.data(),
-            grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
     if (grammar == nullptr) {
         throw std::runtime_error("Failed to initialize llama_grammar");
     }
@@ -122,7 +97,7 @@ int main(int argc, char** argv) {
     // Validate the input string against the grammar
     size_t error_pos;
     std::string error_msg;
-    bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg);
+    bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);

     if (is_valid) {
         fprintf(stdout, "Input string is valid according to the grammar.\n");
@@ -131,7 +106,7 @@ int main(int argc, char** argv) {
     }

     // Clean up
-    llama_grammar_free(grammar);
+    llama_grammar_free_impl(grammar);

     return 0;
 }
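Condensed for reference (an assumption-laden sketch, not part of the patch): the validator flow after this change, with grammar_str and input_str assumed to have been read from the files given on the command line.

    // sketch only - llama_grammar_validate is the static helper defined in the hunk above
    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
    if (grammar == nullptr) {
        throw std::runtime_error("Failed to initialize llama_grammar");
    }

    size_t error_pos;
    std::string error_msg;

    if (llama_grammar_validate(grammar, input_str, error_pos, error_msg)) {
        fprintf(stdout, "Input string is valid according to the grammar.\n");
    } else {
        fprintf(stdout, "Input string is invalid: %s (at position %zu)\n", error_msg.c_str(), error_pos);
    }

    llama_grammar_free_impl(grammar);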
@@ -9,7 +9,7 @@
 static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
     std::vector<std::vector<float>> result;

-    const llama_model * mdl = llama_get_model(ctx);
+    const llama_model * model = llama_get_model(ctx);

     llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

@@ -18,16 +18,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

         const std::string input_string = instruction + sentences[i];

-        std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
+        std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);

         const int32_t n_toks = inputs.size();

         // GritLM seems to have EOS = ""
         // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
-        // inputs.push_back(llama_token_eos(mdl));
+        // inputs.push_back(llama_token_eos(model));

         // we want to ignore instruction tokens for mean pooling
-        const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
+        const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();

 #ifdef GRIT_DEBUG
         // debug tokens - should be matching as referenced in the GritLM sample
@@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         llama_decode(ctx, batch);

         // get embedding dimensions
-        uint64_t n_embd = llama_n_embd(mdl);
+        uint64_t n_embd = llama_n_embd(model);

         // allocate embedding output
         std::vector<float> emb_unorm(n_embd, 0.0f);
@@ -92,11 +92,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     return result;
 }

-static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
+static std::string generate(llama_context * ctx, llama_sampling * smpl, const std::string & prompt, bool stream) {
     std::string result;

-    const llama_model * mdl = llama_get_model(ctx);
-    llama_token eos_token = llama_token_eos(mdl);
+    const llama_model * model = llama_get_model(ctx);
+    llama_token eos_token = llama_token_eos(model);

     llama_kv_cache_clear(ctx);
     llama_set_embeddings(ctx, false);
@@ -104,28 +104,27 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo

     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

-    std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
+    std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
     int32_t i_current_token = 0;

     while (true) {
         llama_batch_clear(bat);
-        auto n_inputs = (int32_t)inputs.size();
+        {
+            const int32_t n_inputs = inputs.size();
+
             for (int32_t i = 0; i < n_inputs; i++) {
                 llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
             }
+        }
         inputs.clear();

         llama_decode(ctx, bat);
-        auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

-        auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
-        auto n_candidates = (int32_t)candidates.size();
-        for (int32_t token = 0; token < n_candidates; token++) {
-            candidates[token] = llama_token_data{ token, logits[token], 0.0f };
-        }
-        auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
+        const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

-        llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
+        llama_sampling_set_logits(smpl, logits);
+
+        llama_token token = llama_sampling_sample_greedy(smpl, nullptr);
         if (token == eos_token) {
             break;
         }
@@ -167,10 +166,12 @@ int main(int argc, char * argv[]) {

     llama_backend_init();

-    llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);

     // create generation context
-    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
+    llama_context * ctx = llama_new_context_with_model(model, cparams);
+
+    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());

     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -191,7 +192,7 @@ int main(int argc, char * argv[]) {
     const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
     const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));

-    const int n_embd = llama_n_embd(mdl);
+    const int n_embd = llama_n_embd(model);

     const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
     const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
@@ -208,11 +209,12 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx, prompt, true);
+        std::string response = generate(ctx, smpl, prompt, true);
     }

+    llama_sampling_free(smpl);
     llama_free(ctx);
-    llama_free_model(mdl);
+    llama_free_model(model);
     llama_backend_free();

     return 0;
@@ -638,7 +638,7 @@ int main(int argc, char ** argv) {

     g_collector.save_imatrix();

-    llama_print_timings(ctx);
+    llama_print_timings(ctx, nullptr);

     llama_free(ctx);
     llama_free_model(model);
@@ -2,7 +2,6 @@

  #include "console.h"
  #include "llama.h"
- #include "grammar-parser.h"

  #include <cassert>
  #include <cinttypes>
@@ -34,6 +33,7 @@

  static llama_context ** g_ctx;
  static llama_model ** g_model;
+ static llama_sampling ** g_smpl;
  static gpt_params * g_params;
  static std::vector<llama_token> * g_input_tokens;
  static std::ostringstream * g_output_ss;
@@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
  } else {
  console::cleanup();
  printf("\n");
- llama_print_timings(*g_ctx);
+ llama_print_timings(*g_ctx, *g_smpl);
  write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
  _exit(130);
  }
@@ -103,7 +103,6 @@ static void sigint_handler(int signo) {

  int main(int argc, char ** argv) {
  gpt_params params;
- llama_sampling_params & sparams = params.sparams;
  g_params = &params;

  if (!gpt_params_parse(argc, argv, params)) {
@@ -111,6 +110,8 @@ int main(int argc, char ** argv) {
  return 1;
  }

+ auto & sparams = params.sparams;
+
  #ifndef LOG_DISABLE_LOGS
  log_set_target(log_filename_generator("infill", "log"));
  LOG_TEE("Log start\n");
@@ -156,26 +157,21 @@ int main(int argc, char ** argv) {
  LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
  }

- LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
- LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
-
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
- LOG_TEE("%s: seed = %u\n", __func__, params.seed);
-
- std::mt19937 rng(params.seed);
+ print_build_info();
+
+ LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

  LOG("%s: llama backend init\n", __func__);
  llama_backend_init();
  llama_numa_init(params.numa);

- llama_model * model;
- llama_context * ctx;
+ llama_model * model = nullptr;
+ llama_context * ctx = nullptr;
+ llama_sampling * smpl = nullptr;

  g_model = &model;
  g_ctx = &ctx;
+ g_smpl = &smpl;

  // load the model and apply lora adapter, if any
  LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -305,7 +301,7 @@ int main(int argc, char ** argv) {
  LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
  }
  }
- LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
+ LOG_TEE("sampling: \n%s\n", sparams.print_all().c_str());
  LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
  LOG_TEE("\n\n");

@@ -349,7 +345,7 @@ int main(int argc, char ** argv) {

  std::vector<llama_token> embd;

- struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+ smpl = llama_sampling_init(model, sparams);

  while (n_remain != 0 || params.interactive) {
  // predict
@@ -421,11 +417,11 @@ int main(int argc, char ** argv) {
  embd.clear();

  if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
- const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
+ const llama_token id = llama_sampling_sample(smpl, ctx, -1);

- llama_sampling_accept(ctx_sampling, ctx, id, true);
+ llama_sampling_accept(smpl, id, true);

- LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+ // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());

  embd.push_back(id);

@@ -444,7 +440,7 @@ int main(int argc, char ** argv) {

  // push the prompt in the sampling context in order to apply repetition penalties later
  // for the prompt, we don't apply grammar rules
- llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+ llama_sampling_accept(smpl, embd_inp[n_consumed], false);

  ++n_consumed;
  if ((int) embd.size() >= params.n_batch) {
@@ -476,7 +472,7 @@ int main(int argc, char ** argv) {
  // if not currently processing queued inputs;
  if ((int) embd_inp.size() <= n_consumed) {
  // deal with eot token in infill mode
- if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
+ if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
  if (is_interacting && !params.interactive_first) {
  // print an eot token
  printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
@@ -542,7 +538,7 @@ int main(int argc, char ** argv) {
  is_interacting = false;
  }
  // deal with end of generation tokens in interactive mode
- else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
+ else if (llama_token_is_eog(model, llama_sampling_last(smpl))) {
  LOG("found EOS token\n");

  if (params.interactive) {
@@ -615,7 +611,7 @@ int main(int argc, char ** argv) {

  if (n_past > 0) {
  if (is_interacting) {
- llama_sampling_reset(ctx_sampling);
+ llama_sampling_reset(smpl);
  }
  is_interacting = false;
  }
@@ -638,13 +634,13 @@ int main(int argc, char ** argv) {
  fflush(stdout);
  }

- llama_print_timings(ctx);
+ llama_print_timings(ctx, smpl);
  write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

  llama_free(ctx);
  llama_free_model(model);

- llama_sampling_free(ctx_sampling);
+ llama_sampling_free(smpl);
  llama_backend_free();

  #ifndef LOG_DISABLE_LOGS
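The hunks above all follow the same migration pattern: the example-local llama_sampling_context from common/sampling is replaced by a llama_sampling object owned by libllama, and llama_print_timings now optionally takes that sampler. A minimal sketch of the resulting generation loop, using only the calls visible in this diff; model, ctx, sparams, embd and n_remain are assumed to exist as in the examples above, and the loop skeleton is illustrative rather than a literal excerpt:

    // sketch only: one sampler per generation stream
    llama_sampling * smpl = llama_sampling_init(model, sparams);

    while (n_remain != 0) {
        // after llama_decode(...) has produced logits, pick the next token
        const llama_token id = llama_sampling_sample(smpl, ctx, -1);

        // update sampler state (prev tokens, repetition penalties, grammar)
        llama_sampling_accept(smpl, id, true);

        if (llama_token_is_eog(model, id)) {
            break;
        }

        embd.push_back(id);
        n_remain--;
    }

    llama_print_timings(ctx, smpl); // sampling timings reported alongside the context
    llama_sampling_free(smpl);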
@@ -1574,7 +1574,7 @@ int main(int argc, char ** argv) {
  fflush(p_err->fout);
  }

- llama_print_timings(ctx);
+ llama_print_timings(ctx, nullptr);

  llama_free(ctx);
@@ -120,7 +120,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
  LOGi("Using %d threads", n_threads);

  llama_context_params ctx_params = llama_context_default_params();
- ctx_params.seed = 1234;
  ctx_params.n_ctx = 2048;
  ctx_params.n_threads = n_threads;
  ctx_params.n_threads_batch = n_threads;
@@ -380,11 +380,13 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
  JNIEnv * env,
  jobject,
  jlong context_pointer,
+ jlong sampling_pointer,
  jlong batch_pointer,
  jint n_len,
  jobject intvar_ncur
  ) {
  const auto context = reinterpret_cast<llama_context *>(context_pointer);
+ const auto sampling = reinterpret_cast<llama_sampling *>(sampling_pointer);
  const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
  const auto model = llama_get_model(context);

@@ -392,20 +394,12 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
  if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
  if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");

- auto n_vocab = llama_n_vocab(model);
- auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
-
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
- }
-
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ const auto * logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+
+ llama_sampling_set_logits(sampling, logits);

  // sample the most likely token
- const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+ const auto new_token_id = llama_sampling_sample_greedy(sampling, nullptr);

  const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
  if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
  actor LlamaContext {
  private var model: OpaquePointer
  private var context: OpaquePointer
+ private var sampling: OpaquePointer
  private var batch: llama_batch
  private var tokens_list: [llama_token]
  var is_done: Bool = false
@@ -42,9 +43,11 @@ actor LlamaContext {
  self.tokens_list = []
  self.batch = llama_batch_init(512, 0, 1)
  self.temporary_invalid_cchars = []
+ self.sampling = llama_sampling_init(context, llama_sampling_default_params())
  }

  deinit {
+ llama_sampling_free(sampling)
  llama_batch_free(batch)
  llama_free(context)
  llama_free_model(model)
@@ -69,7 +72,6 @@ actor LlamaContext {
  print("Using \(n_threads) threads")

  var ctx_params = llama_context_default_params()
- ctx_params.seed = 1234
  ctx_params.n_ctx = 2048
  ctx_params.n_threads = Int32(n_threads)
  ctx_params.n_threads_batch = Int32(n_threads)
@@ -147,17 +149,9 @@ actor LlamaContext {
  let n_vocab = llama_n_vocab(model)
  let logits = llama_get_logits_ith(context, batch.n_tokens - 1)

- var candidates = Array<llama_token_data>()
- candidates.reserveCapacity(Int(n_vocab))
-
- for token_id in 0..<n_vocab {
- candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
- }
- candidates.withUnsafeMutableBufferPointer() { buffer in
- var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
-
- new_token_id = llama_sample_token_greedy(context, &candidates_p)
- }
+ llama_sampling_set_logits(sampling, logits);
+
+ new_token_id = llama_sampling_sample_greedy(sampling, nil)

  if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
  print("\n")
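In both bindings above the manual candidate-array construction disappears: the caller now hands the raw logits to the sampler and asks it for the greedy token. A minimal C++ sketch of that path, using only calls that appear in this diff; ctx, model and batch are assumed to exist as in the surrounding examples, and error handling is omitted:

    // sketch only: logits-driven greedy sampling through the new API
    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());

    // logits of the last token in the previously decoded batch
    const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);

    // the sampler now owns the candidate list; callers only provide the logits
    llama_sampling_set_logits(smpl, logits);

    // greedy pick; nullptr mirrors the bindings, which pass no explicit candidate array
    const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);

    llama_sampling_free(smpl);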
@@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
  return true;
  }

- static const char * sample(struct llama_sampling_context * ctx_sampling,
+ static const char * sample(struct llama_sampling * smpl,
  struct llama_context * ctx_llama,
  int * n_past) {
- const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
- llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
+ const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
+ llama_sampling_accept(smpl, id, true);
  static std::string ret;
  if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
  ret = "</s>";
@@ -191,15 +191,15 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_

  LOG_TEE("\n");

- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
- if (!ctx_sampling) {
+ struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
+ if (!smpl) {
  fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
  exit(1);
  }

  std::string response = "";
  for (int i = 0; i < max_tgt_len; i++) {
- const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
+ const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
  response += tmp;
  if (strcmp(tmp, "</s>") == 0) break;
  if (strstr(tmp, "###")) break; // Yi-VL behavior
@@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
  fflush(stdout);
  }

- llama_sampling_free(ctx_sampling);
+ llama_sampling_free(smpl);
  printf("\n");
  }

@@ -310,7 +310,7 @@ int main(int argc, char ** argv) {
  // process the prompt
  process_prompt(ctx_llava, image_embed, &params, params.prompt);

- llama_print_timings(ctx_llava->ctx_llama);
+ llama_print_timings(ctx_llava->ctx_llama, nullptr);
  llava_image_embed_free(image_embed);
  ctx_llava->model = NULL;
  llava_free(ctx_llava);
@@ -327,7 +327,7 @@ int main(int argc, char ** argv) {
  // process the prompt
  process_prompt(ctx_llava, image_embed, &params, params.prompt);

- llama_print_timings(ctx_llava->ctx_llama);
+ llama_print_timings(ctx_llava->ctx_llama, nullptr);
  llava_image_embed_free(image_embed);
  ctx_llava->model = NULL;
  llava_free(ctx_llava);
@@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
  LOG_TEE("%s: image token past: %d\n", __func__, n_past);
  }

- static const char * sample(struct llama_sampling_context * ctx_sampling,
+ static const char * sample(struct llama_sampling * smpl,
  struct llama_context * ctx_llama,
  int * n_past) {
- const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
- llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
+ const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
+ llama_sampling_accept(smpl, id, true);
  static std::string ret;
  if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
  ret = "</s>";
@@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
  return ctx_llava;
  }

- static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
+ static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
  std::string user_prompt = prompt;
  int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
  if (!is_first) {
@@ -238,13 +238,13 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla

  LOG_TEE("\n");

- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
- return ctx_sampling;
+ struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
+ return smpl;
  }

- static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){
+ static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){

- const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
+ const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
  return tmp;
  }

@@ -278,12 +278,12 @@ int main(int argc, char ** argv) {
  if (!params.prompt.empty()) {
  LOG_TEE("<user>%s\n", params.prompt.c_str());
  LOG_TEE("<assistant>");
- auto ctx_sampling = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
+ auto smpl = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
  const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
  std::string response = "";
  bool have_tmp = false;
  for (int i = 0; i < max_tgt_len; i++) {
- auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
+ auto tmp = llama_loop(ctx_llava, smpl, n_past);
  response += tmp;
  if (strcmp(tmp, "</s>") == 0){
  if(!have_tmp)continue;
@@ -296,18 +296,18 @@ int main(int argc, char ** argv) {

  fflush(stdout);
  }
- llama_sampling_free(ctx_sampling);
+ llama_sampling_free(smpl);
  }else {
  while (true) {
  LOG_TEE("<user>");
  std::string prompt;
  std::getline(std::cin, prompt);
  LOG_TEE("<assistant>");
- auto ctx_sampling = llama_init(ctx_llava, &params, prompt, n_past, true);
+ auto smpl = llama_init(ctx_llava, &params, prompt, n_past, true);
  const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
  std::string response = "";
  for (int i = 0; i < max_tgt_len; i++) {
- auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
+ auto tmp = llama_loop(ctx_llava, smpl, n_past);
  response += tmp;
  if (strcmp(tmp, "</s>") == 0) break;
  if (strstr(tmp, "###")) break; // Yi-VL behavior
@@ -315,11 +315,11 @@ int main(int argc, char ** argv) {
  if (strstr(response.c_str(), "<user>")) break; // minicpm-v
  fflush(stdout);
  }
- llama_sampling_free(ctx_sampling);
+ llama_sampling_free(smpl);
  }
  }
  printf("\n");
- llama_print_timings(ctx_llava->ctx_llama);
+ llama_print_timings(ctx_llava->ctx_llama, nullptr);

  ctx_llava->model = NULL;
  llava_free(ctx_llava);
@@ -1,7 +1,6 @@
  #include "common.h"
  #include "llama.h"

- #include <cmath>
  #include <cstdio>
  #include <string>
  #include <vector>
@@ -118,7 +117,7 @@ int main(int argc, char ** argv) {
  llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

  // target model sampling context
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+ struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);

  // verification n-grams
  std::vector<ngram_data> ngrams_cur(G);
@@ -159,9 +158,9 @@ int main(int argc, char ** argv) {

  // sample first token
  {
- id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+ id = llama_sampling_sample(smpl, ctx, 0);

- llama_sampling_accept(ctx_sampling, ctx, id, true);
+ llama_sampling_accept(smpl, id, true);

  {
  const std::string token_str = llama_token_to_piece(ctx, id);
@@ -284,9 +283,9 @@ int main(int argc, char ** argv) {
  }

  // sample the next token
- id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
+ id = llama_sampling_sample(smpl, ctx, i_batch);

- llama_sampling_accept(ctx_sampling, ctx, id, true);
+ llama_sampling_accept(smpl, id, true);

  // print
  {
@@ -361,7 +360,7 @@ int main(int argc, char ** argv) {
  if (v == 0) {
  // sample from the last level
  for (int i = 0; i < W; i++) {
- tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+ tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
  }
  } else {
  for (int i = 0; i < W; i++) {
@@ -468,10 +467,10 @@ int main(int argc, char ** argv) {
  LOG_TEE("n_predict = %d\n", n_predict);
  LOG_TEE("n_accept = %d\n", n_accept);

- llama_print_timings(ctx);
+ llama_print_timings(ctx, smpl);

  llama_kv_cache_view_free(&kvc_view);
- llama_sampling_free(ctx_sampling);
+ llama_sampling_free(smpl);

  llama_batch_free(batch);
@@ -3,13 +3,11 @@
  #include "common.h"
  #include "ngram-cache.h"

- #include <cmath>
  #include <cstdint>
  #include <cstdio>
  #include <fstream>
  #include <string>
  #include <vector>
- #include <unordered_map>

  int main(int argc, char ** argv){
  gpt_params params;
@@ -106,7 +104,7 @@ int main(int argc, char ** argv){

  bool has_eos = false;

- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+ struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);

  std::vector<llama_token> draft;

@@ -130,9 +128,9 @@ int main(int argc, char ** argv){
  int i_dft = 0;
  while (true) {
  // sample from the target model
- llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+ llama_token id = llama_sampling_sample(smpl, ctx, i_dft);

- llama_sampling_accept(ctx_sampling, ctx, id, true);
+ llama_sampling_accept(smpl, id, true);

  const std::string token_str = llama_token_to_piece(ctx, id);

@@ -241,9 +239,9 @@ int main(int argc, char ** argv){
  LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

  LOG_TEE("\ntarget:\n");
- llama_print_timings(ctx);
+ llama_print_timings(ctx, smpl);

- llama_sampling_free(ctx_sampling);
+ llama_sampling_free(smpl);
  llama_batch_free(batch_tgt);

  llama_free(ctx);
@@ -33,6 +33,7 @@

  static llama_context ** g_ctx;
  static llama_model ** g_model;
+ static llama_sampling ** g_smpl;
  static gpt_params * g_params;
  static std::vector<llama_token> * g_input_tokens;
  static std::ostringstream * g_output_ss;
@@ -105,7 +106,7 @@ static void sigint_handler(int signo) {
  } else {
  console::cleanup();
  printf("\n");
- llama_print_timings(*g_ctx);
+ llama_print_timings(*g_ctx, *g_smpl);
  write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
  _exit(130);
  }
@@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v

  static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
  llama_chat_msg new_msg{role, content};
- auto formatted = llama_chat_format_single(
- model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+ auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
  chat_msgs.push_back({role, content});
  LOG("formatted: %s\n", formatted.c_str());
  return formatted;
@@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
  return 1;
  }

- llama_sampling_params & sparams = params.sparams;
+ auto & sparams = params.sparams;

  #ifndef LOG_DISABLE_LOGS
  log_set_target(log_filename_generator("main", "log"));
@@ -183,27 +183,23 @@ int main(int argc, char ** argv) {
  LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
  }

- LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
- LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
-
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
- LOG_TEE("%s: seed = %u\n", __func__, params.seed);
-
- std::mt19937 rng(params.seed);
+ print_build_info();
+
+ LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

  LOG("%s: llama backend init\n", __func__);
  llama_backend_init();
  llama_numa_init(params.numa);

- llama_model * model;
- llama_context * ctx;
- llama_context * ctx_guidance = NULL;
+ llama_model * model = nullptr;
+ llama_context * ctx = nullptr;
+ llama_sampling * smpl = nullptr;

  std::vector<llama_chat_msg> chat_msgs;

  g_model = &model;
  g_ctx = &ctx;
+ g_smpl = &smpl;

  // load the model and apply lora adapter, if any
  LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -211,10 +207,6 @@ int main(int argc, char ** argv) {

  model = llama_init.model;
  ctx = llama_init.context;
- if (sparams.cfg_scale > 1.f) {
- struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
- ctx_guidance = llama_new_context_with_model(model, lparams);
- }

  if (model == NULL) {
  LOG_TEE("%s: error: unable to load model\n", __func__);
@@ -251,9 +243,6 @@ int main(int argc, char ** argv) {
  }

  llama_attach_threadpool(ctx, threadpool, threadpool_batch);
- if (ctx_guidance) {
- llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
- }

  const int n_ctx_train = llama_n_ctx_train(model);
  const int n_ctx = llama_n_ctx(ctx);
@@ -337,24 +326,6 @@ int main(int argc, char ** argv) {
  }

  // Tokenize negative prompt
- std::vector<llama_token> guidance_inp;
- int guidance_offset = 0;
- int original_prompt_len = 0;
- if (ctx_guidance) {
- LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
-
- guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
- LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
-
- std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
- LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
-
- original_prompt_len = original_inp.size();
- guidance_offset = (int)guidance_inp.size() - original_prompt_len;
- LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
- LOG("guidance_offset: %s", log_tostr(guidance_offset));
- }

  if ((int) embd_inp.size() > n_ctx - 4) {
  LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
  return 1;
@@ -421,15 +392,6 @@ int main(int argc, char ** argv) {
  LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
  }

- if (ctx_guidance) {
- LOG_TEE("\n");
- LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
- LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
- for (int i = 0; i < (int) guidance_inp.size(); i++) {
- LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
- }
- }

  if (params.n_keep > add_bos) {
  LOG_TEE("%s: static prompt based on n_keep: '", __func__);
  for (int i = 0; i < params.n_keep; i++) {
@@ -495,8 +457,8 @@ int main(int argc, char ** argv) {
  }
  }
  }
- LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
- LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
+ LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str());
+ LOG_TEE("sampling order: \n%s\n", sparams.print_samplers().c_str());
  LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);

  // group-attention state
@@ -543,7 +505,6 @@ int main(int argc, char ** argv) {
  int n_remain = params.n_predict;
  int n_consumed = 0;
  int n_session_consumed = 0;
- int n_past_guidance = 0;

  std::vector<int> input_tokens; g_input_tokens = &input_tokens;
  std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@@ -555,7 +516,6 @@ int main(int argc, char ** argv) {
  display = params.display_prompt;

  std::vector<llama_token> embd;
- std::vector<llama_token> embd_guidance;

  // tokenized antiprompts
  std::vector<std::vector<llama_token>> antiprompt_ids;
@@ -565,8 +525,8 @@ int main(int argc, char ** argv) {
  antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
  }

- struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
- if (!ctx_sampling) {
+ smpl = llama_sampling_init(model, sparams);
+ if (!smpl) {
  fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
  exit(1);
  }
@@ -612,7 +572,7 @@ int main(int argc, char ** argv) {
  // if we run out of context:
  // - take the n_keep first tokens from the original prompt (via n_past)
  // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
- if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) >= n_ctx) {
+ if (n_past + (int) embd.size() >= n_ctx) {
  if (params.n_predict == -2) {
  LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
  break;
@@ -629,11 +589,7 @@ int main(int argc, char ** argv) {

  n_past -= n_discard;

- if (ctx_guidance) {
- n_past_guidance -= n_discard;
- }
-
- LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+ LOG("after swap: n_past = %d\n", n_past);

  LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());

@@ -686,46 +642,6 @@ int main(int argc, char ** argv) {
  }
  }

- // evaluate tokens in batches
- // embd is typically prepared beforehand to fit within a batch, but not always
- if (ctx_guidance) {
- int input_size = 0;
- llama_token * input_buf = NULL;
-
- if (n_past_guidance < (int) guidance_inp.size()) {
- // Guidance context should have the same data with these modifications:
- //
- // * Replace the initial prompt
- // * Shift everything by guidance_offset
- embd_guidance = guidance_inp;
- if (embd.begin() + original_prompt_len < embd.end()) {
- embd_guidance.insert(
- embd_guidance.end(),
- embd.begin() + original_prompt_len,
- embd.end()
- );
- }
-
- input_buf = embd_guidance.data();
- input_size = embd_guidance.size();
-
- LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
- } else {
- input_buf = embd.data();
- input_size = embd.size();
- }
-
- for (int i = 0; i < input_size; i += params.n_batch) {
- int n_eval = std::min(input_size - i, params.n_batch);
- if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
- LOG_TEE("%s : failed to eval\n", __func__);
- return 1;
- }
-
- n_past_guidance += n_eval;
- }
- }

  for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
  int n_eval = (int) embd.size() - i;
  if (n_eval > params.n_batch) {
@@ -755,7 +671,6 @@ int main(int argc, char ** argv) {
  }

  embd.clear();
- embd_guidance.clear();

  if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
  // optionally save the session on first sample (for faster prompt loading next time)
@@ -766,11 +681,11 @@ int main(int argc, char ** argv) {
  LOG("saved session to %s\n", path_session.c_str());
  }

- const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+ const llama_token id = llama_sampling_sample(smpl, ctx, -1);

- llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
+ llama_sampling_accept(smpl, id, /* apply_grammar= */ true);

- LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+ // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());

  embd.push_back(id);

@@ -789,7 +704,7 @@ int main(int argc, char ** argv) {

  // push the prompt in the sampling context in order to apply repetition penalties later
  // for the prompt, we don't apply grammar rules
- llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
+ llama_sampling_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false);

  ++n_consumed;
  if ((int) embd.size() >= params.n_batch) {
@@ -832,7 +747,7 @@ int main(int argc, char ** argv) {
  // check for reverse prompt in the last n_prev tokens
  if (!params.antiprompt.empty()) {
  const int n_prev = 32;
- const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
+ const std::string last_output = llama_sampling_prev_str(smpl, ctx, n_prev);

  is_antiprompt = false;
  // Check if each of the reverse prompts appears at the end of the output.
@@ -854,7 +769,7 @@ int main(int argc, char ** argv) {
  }

  // check for reverse prompt using special tokens
- llama_token last_token = llama_sampling_last(ctx_sampling);
+ llama_token last_token = llama_sampling_last(smpl);
  for (std::vector<llama_token> ids : antiprompt_ids) {
  if (ids.size() == 1 && last_token == ids[0]) {
  if (params.interactive) {
@@ -871,7 +786,7 @@ int main(int argc, char ** argv) {
  }

  // deal with end of generation tokens in interactive mode
- if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
+ if (llama_token_is_eog(model, llama_sampling_last(smpl))) {
  LOG("found an EOG token\n");

  if (params.interactive) {
@@ -892,7 +807,7 @@ int main(int argc, char ** argv) {

  // if current token is not EOG, we add it to current assistant message
  if (params.conversation) {
- auto id = llama_sampling_last(ctx_sampling);
+ auto id = llama_sampling_last(smpl);
  assistant_ss << llama_token_to_piece(ctx, id, false);
  }

@@ -988,7 +903,7 @@ int main(int argc, char ** argv) {

  if (n_past > 0) {
  if (is_interacting) {
- llama_sampling_reset(ctx_sampling);
+ llama_sampling_reset(smpl);
  }
  is_interacting = false;
  }
@@ -1013,14 +928,13 @@ int main(int argc, char ** argv) {
  llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
  }

- llama_print_timings(ctx);
+ llama_print_timings(ctx, smpl);
  write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

- if (ctx_guidance) { llama_free(ctx_guidance); }
  llama_free(ctx);
  llama_free_model(model);

- llama_sampling_free(ctx_sampling);
+ llama_sampling_free(smpl);
  llama_backend_free();

  ggml_threadpool_free(threadpool);
@@ -50,8 +50,8 @@ static std::vector<std::string> k_prompts = {

  struct client {
  ~client() {
- if (ctx_sampling) {
- llama_sampling_free(ctx_sampling);
+ if (smpl) {
+ llama_sampling_free(smpl);
  }
  }

@@ -72,7 +72,7 @@ struct client {
  std::string prompt;
  std::string response;

- struct llama_sampling_context * ctx_sampling = nullptr;
+ struct llama_sampling * smpl = nullptr;
  };

  static void print_date_time() {
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
  for (size_t i = 0; i < clients.size(); ++i) {
  auto & client = clients[i];
  client.id = i;
- client.ctx_sampling = llama_sampling_init(params.sparams);
+ client.smpl = llama_sampling_init(model, params.sparams);
  }

  std::vector<llama_token> tokens_system;
@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
  client.prompt = client.input + "\nAssistant:";
  client.response = "";

- llama_sampling_reset(client.ctx_sampling);
+ llama_sampling_reset(client.smpl);

  // do not prepend BOS because we have a system prompt!
  std::vector<llama_token> tokens_prompt;
@@ -341,9 +341,9 @@ int main(int argc, char ** argv) {
  //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
  // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);

- const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+ const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i);

- llama_sampling_accept(client.ctx_sampling, ctx, id, true);
+ llama_sampling_accept(client.smpl, id, true);

  if (client.n_decoded == 1) {
  // start measuring generation time after the first token to make sure all concurrent clients
@@ -413,7 +413,8 @@ int main(int argc, char ** argv) {

  LOG_TEE("\n");

- llama_print_timings(ctx);
+ // TODO: print sampling/grammar timings for all clients
+ llama_print_timings(ctx, nullptr);

  llama_batch_free(batch);
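In the parallel example each client now owns its own llama_sampling object, initialized once against the shared model and reset between requests. A minimal sketch of that pattern, using only calls from this diff; the client struct, params.sparams and the batch index are simplified placeholders, not a literal excerpt:

    // sketch only: one sampler per concurrent sequence over a shared model/context
    struct client_state {
        llama_sampling * smpl = nullptr;
    };

    std::vector<client_state> clients(4);
    for (auto & c : clients) {
        c.smpl = llama_sampling_init(model, params.sparams); // per-client sampler state
    }

    // when a client starts a new request, only its own sampler is reset
    llama_sampling_reset(clients[0].smpl);

    // position of this client's token in the last decoded batch (assumed for the sketch)
    const int i_batch = 0;

    const llama_token id = llama_sampling_sample(clients[0].smpl, ctx, i_batch);
    llama_sampling_accept(clients[0].smpl, id, true);

    for (auto & c : clients) {
        llama_sampling_free(c.smpl);
    }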
@@ -26,8 +26,6 @@ int main(int argc, char ** argv) {
  return 1;
  }

- srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed);
-
  int n_junk = params.n_junk;
  int n_keep = params.n_keep;
  int n_grp = params.grp_attn_n;
@@ -80,12 +78,13 @@ int main(int argc, char ** argv) {
  GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");

  llama_context * ctx = llama_new_context_with_model(model, ctx_params);

  if (ctx == NULL) {
  fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
  return 1;
  }

+ llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());

  // tokenize the prompt
  std::vector<llama_token> tokens_list;
  tokens_list = ::llama_tokenize(ctx, params.prompt, true);
@@ -217,20 +216,12 @@ int main(int argc, char ** argv) {
  while (n_cur <= n_len) {
  // sample the next token
  {
- auto n_vocab = llama_n_vocab(model);
- auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
- }
-
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+
+ llama_sampling_set_logits(smpl, logits);

  // sample the most likely token
- const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+ const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);

  // is it an end of generation?
  if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
@@ -267,12 +258,13 @@ int main(int argc, char ** argv) {
  LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
  __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

- llama_print_timings(ctx);
+ llama_print_timings(ctx, nullptr);

  fprintf(stderr, "\n");

  llama_batch_free(batch);

+ llama_sampling_free(smpl);
  llama_free(ctx);
  llama_free_model(model);
@@ -2007,13 +2007,7 @@ int main(int argc, char ** argv) {

  print_build_info();

- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
- fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
- std::mt19937 rng(params.seed);
+ LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

  llama_backend_init();
  llama_numa_init(params.numa);
@@ -2054,7 +2048,7 @@ int main(int argc, char ** argv) {
  results = perplexity(ctx, params, n_ctx);
  }

- llama_print_timings(ctx);
+ llama_print_timings(ctx, nullptr);
  write_logfile(ctx, params, model, results);

  llama_free(ctx);
@@ -1,7 +1,7 @@
- #define LLAMA_API_INTERNAL
  #include "common.h"
  #include "ggml.h"
  #include "llama.h"
+ #include "llama-impl.h"

  #include <algorithm>
  #include <cassert>
@@ -320,7 +320,6 @@ int main(int argc, char ** argv) {

  auto cparams = llama_context_default_params();
  cparams.n_ctx = 256;
- cparams.seed = 1;

  ctx = llama_new_context_with_model(model, cparams);
@@ -294,8 +294,8 @@ int main(int argc, char ** argv) {
  }

  // clean up
+ llama_print_timings(ctx, nullptr);
  llama_batch_free(query_batch);
- llama_print_timings(ctx);
  llama_free(ctx);
  llama_free_model(model);
  llama_backend_free();
@ -3,12 +3,12 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <chrono>
|
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
params.prompt = "The quick brown fox";
|
params.prompt = "The quick brown fox";
|
||||||
|
params.sparams.seed = 1234;
|
||||||
|
|
||||||
if (!gpt_params_parse(argc, argv, params)) {
|
if (!gpt_params_parse(argc, argv, params)) {
|
||||||
gpt_params_print_usage(argc, argv, params);
|
gpt_params_print_usage(argc, argv, params);
|
||||||
@ -38,6 +38,11 @@ int main(int argc, char ** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_sampling_params sparams = llama_sampling_default_params();
|
||||||
|
sparams.seed = params.sparams.seed;
|
||||||
|
|
||||||
|
llama_sampling * smpl = llama_sampling_init(model, sparams);
|
||||||
|
|
||||||
// tokenize prompt
|
// tokenize prompt
|
||||||
auto tokens = llama_tokenize(ctx, params.prompt, true);
|
auto tokens = llama_tokenize(ctx, params.prompt, true);
|
||||||
|
|
||||||
@ -64,16 +69,11 @@ int main(int argc, char ** argv) {
|
|||||||
printf("\nfirst run: %s", params.prompt.c_str());
|
printf("\nfirst run: %s", params.prompt.c_str());
|
||||||
|
|
||||||
for (auto i = 0; i < params.n_predict; i++) {
|
for (auto i = 0; i < params.n_predict; i++) {
|
||||||
auto * logits = llama_get_logits(ctx);
|
const auto * logits = llama_get_logits(ctx);
|
||||||
auto n_vocab = llama_n_vocab(model);
|
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
llama_sampling_set_logits(smpl, logits);
|
||||||
candidates.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
auto next_token = llama_sampling_sample_dist(smpl, nullptr);
|
||||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
|
||||||
}
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
|
||||||
auto next_token = llama_sample_token(ctx, &candidates_p);
|
|
||||||
auto next_token_str = llama_token_to_piece(ctx, next_token);
|
auto next_token_str = llama_token_to_piece(ctx, next_token);
|
||||||
|
|
||||||
printf("%s", next_token_str.c_str());
|
printf("%s", next_token_str.c_str());
|
||||||
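Note (not part of the patch): the hunk above is the pattern every example converges on. A minimal sketch of a single sampling step with the new API, assuming `ctx` and `smpl` exist as in the surrounding example; only the function names shown in this diff are used:

```cpp
// Illustrative sketch only: one sampling step with the new llama_sampling API.
const float * logits = llama_get_logits(ctx);                 // raw logits of the last decoded token
llama_sampling_set_logits(smpl, logits);                      // (re)build the sampler's internal candidate array
llama_token tok = llama_sampling_sample_dist(smpl, nullptr);  // nullptr -> sample from the internal candidates
std::string piece = llama_token_to_piece(ctx, tok);           // detokenize for printing, as in the example
```

The candidate construction and the old `llama_sample_token()` call move behind `llama_sampling_set_logits()`/`llama_sampling_sample_dist()`, which is why the per-token `std::vector<llama_token_data>` boilerplate disappears from each example.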
@@ -96,6 +96,8 @@ int main(int argc, char ** argv) {
     // make new context
     auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));

+    llama_sampling * smpl2 = llama_sampling_init(model, sparams);
+
     printf("\nsecond run: %s", params.prompt.c_str());

     // load state (rng, logits, embedding and kv_cache) from file

@@ -124,15 +126,11 @@ int main(int argc, char ** argv) {

     // second run
     for (auto i = 0; i < params.n_predict; i++) {
-        auto * logits = llama_get_logits(ctx2);
-        auto n_vocab = llama_n_vocab(model);
-
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-        }
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-        auto next_token = llama_sample_token(ctx2, &candidates_p);
+        const auto * logits = llama_get_logits(ctx2);
+
+        llama_sampling_set_logits(smpl2, logits);
+
+        auto next_token = llama_sampling_sample_dist(smpl2, nullptr);
         auto next_token_str = llama_token_to_piece(ctx2, next_token);

         printf("%s", next_token_str.c_str());

@@ -157,7 +155,9 @@ int main(int argc, char ** argv) {
     }

     // make new context
-    auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
+    auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));

+    llama_sampling * smpl3 = llama_sampling_init(model, sparams);
+
     printf("\nsingle seq run: %s", params.prompt.c_str());

@@ -215,15 +215,11 @@ int main(int argc, char ** argv) {

     // third run with seq 1 instead of 0
     for (auto i = 0; i < params.n_predict; i++) {
-        auto * logits = llama_get_logits(ctx3);
-        auto n_vocab = llama_n_vocab(model);
-
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-        }
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-        auto next_token = llama_sample_token(ctx3, &candidates_p);
+        const auto * logits = llama_get_logits(ctx3);
+
+        llama_sampling_set_logits(smpl3, logits);
+
+        auto next_token = llama_sampling_sample_dist(smpl3, nullptr);
         auto next_token_str = llama_token_to_piece(ctx3, next_token);

         printf("%s", next_token_str.c_str());

@@ -240,6 +236,10 @@ int main(int argc, char ** argv) {

     printf("\n");

+    llama_sampling_free(smpl);
+    llama_sampling_free(smpl2);
+    llama_sampling_free(smpl3);
+
     llama_free(ctx3);
     llama_free_model(model);

@@ -470,8 +470,6 @@ node index.js

 `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.

-`penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens. Default: `null`, which is to use the original `prompt`.
-
 `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.

 `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`

@@ -724,7 +722,6 @@ Example:
         "stopping_word": ""
     },
     "penalize_nl": true,
-    "penalty_prompt_tokens": [],
     "presence_penalty": 0.0,
     "prompt": "Say hello to llama.cpp",
     "repeat_last_n": 64,

@@ -748,8 +745,7 @@ Example:
     "tfs_z": 1.0,
     "top_k": 40,
     "top_p": 0.949999988079071,
-    "typical_p": 1.0,
-    "use_penalty_prompt_tokens": false
+    "typical_p": 1.0
   }
 ]
 ```
@@ -3,7 +3,6 @@

 #include "common.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
-#include "grammar-parser.h"

 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT

@@ -173,11 +172,13 @@ struct server_slot {
     std::string stopping_word;

     // sampling
-    llama_token sampled;
-    struct llama_sampling_params sparams;
-    llama_sampling_context * ctx_sampling = nullptr;
     json json_schema;

+    struct gpt_sampling_params sparams;
+
+    llama_token sampled;
+    llama_sampling * smpl = nullptr;

     int32_t ga_i = 0;   // group-attention state
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width

@@ -636,8 +637,8 @@ struct server_context {

         // Clear any sampling context
         for (server_slot & slot : slots) {
-            if (slot.ctx_sampling != nullptr) {
-                llama_sampling_free(slot.ctx_sampling);
+            if (slot.smpl != nullptr) {
+                llama_sampling_free(slot.smpl);
             }
         }

@@ -864,8 +865,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        llama_sampling_params default_sparams = params.sparams;
-        auto & data = task.data;
+        auto default_sparams = params.sparams;
+        const auto & data = task.data;

         if (data.count("__oaicompat") != 0) {
             slot.oaicompat = true;

@@ -882,7 +883,7 @@ struct server_context {
         slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
         slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
         slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
-        slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
+        slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
         slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
         slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
         slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);

@@ -904,7 +905,8 @@ struct server_context {
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
             return false;
-        } else if (data.contains("json_schema") && !data.contains("grammar")) {
+        }
+
+        if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar = json_schema_to_grammar(schema);
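Note (not part of the patch): a hedged sketch of what the schema-to-grammar path amounts to once the grammar lives in libllama. `json_schema_to_grammar()` and `llama_sampling_set_grammar()` both appear in this diff; the `"root"` start-rule name and the direct call (instead of going through `slot.sparams.grammar`) are assumptions for illustration only:

```cpp
// Sketch under assumptions: convert an incoming JSON schema to a GBNF grammar
// and attach it to an already-initialized llama_sampling instance.
json schema = json_value(data, "json_schema", json::object());
std::string grammar = json_schema_to_grammar(schema);       // helper from common/json-schema-to-grammar.h
llama_sampling_set_grammar(smpl, grammar.c_str(), "root");  // "root" start rule is assumed here
```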
@@ -954,56 +956,11 @@ struct server_context {
             }
         }

-        // penalize user-provided tokens
-        {
-            slot.sparams.penalty_prompt_tokens.clear();
-            slot.sparams.use_penalty_prompt_tokens = false;
-
-            const auto & penalty_prompt = data.find("penalty_prompt");
-
-            if (penalty_prompt != data.end()) {
-                if (penalty_prompt->is_string()) {
-                    const auto penalty_prompt_string = penalty_prompt->get<std::string>();
-                    slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
-
-                    if (slot.params.n_predict > 0) {
-                        slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
-                    }
-                    slot.sparams.use_penalty_prompt_tokens = true;
-
-                    LOG_VERBOSE("penalty_prompt_tokens", {
-                        {"id_slot", slot.id},
-                        {"tokens", slot.sparams.penalty_prompt_tokens},
-                    });
-                }
-                else if (penalty_prompt->is_array()) {
-                    const auto n_tokens = penalty_prompt->size();
-                    slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
-
-                    const int n_vocab = llama_n_vocab(model);
-                    for (const auto & penalty_token : *penalty_prompt) {
-                        if (penalty_token.is_number_integer()) {
-                            const auto tok = penalty_token.get<llama_token>();
-                            if (tok >= 0 && tok < n_vocab) {
-                                slot.sparams.penalty_prompt_tokens.push_back(tok);
-                            }
-                        }
-                    }
-                    slot.sparams.use_penalty_prompt_tokens = true;
-
-                    LOG_VERBOSE("penalty_prompt_tokens", {
-                        {"id_slot", slot.id},
-                        {"tokens", slot.sparams.penalty_prompt_tokens},
-                    });
-                }
-            }
-        }
-
         {
             slot.sparams.logit_bias.clear();

             if (json_value(data, "ignore_eos", false) && has_eos_token) {
-                slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+                slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
             }

             const auto & logit_bias = data.find("logit_bias");
@@ -1024,12 +981,12 @@ struct server_context {
                     if (el[0].is_number_integer()) {
                         llama_token tok = el[0].get<llama_token>();
                         if (tok >= 0 && tok < n_vocab) {
-                            slot.sparams.logit_bias[tok] = bias;
+                            slot.sparams.logit_bias.push_back({tok, bias});
                         }
                     } else if (el[0].is_string()) {
                         auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
                         for (auto tok : toks) {
-                            slot.sparams.logit_bias[tok] = bias;
+                            slot.sparams.logit_bias.push_back({tok, bias});
                         }
                     }
                 }
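Note (not part of the patch): the two hunks above switch logit biases from a token-keyed map to a flat list of `{token, bias}` pairs, matching the new `llama_logit_bias` type declared later in `include/llama.h`. A short sketch of that representation; the vector name and the concrete token id/bias values are illustrative, and `<cmath>` provides `INFINITY`:

```cpp
// Sketch: logit biases as a flat list of {token, bias} pairs instead of a map.
std::vector<llama_logit_bias> logit_bias;
logit_bias.push_back({ llama_token_eos(model), -INFINITY }); // effectively bans EOS ("ignore_eos")
logit_bias.push_back({ 15043, 1.5f });                       // illustrative token id and bias value

// hand the accumulated biases to the sampler (declaration appears in the llama.h hunks below)
llama_sampling_set_logit_bias(smpl, (int32_t) logit_bias.size(), logit_bias.data());
```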
@@ -1051,26 +1008,27 @@ struct server_context {
             }

             {
-                const auto & samplers_sequence = data.find("samplers");
-                if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
+                const auto & samplers = data.find("samplers");
+                if (samplers != data.end() && samplers->is_array()) {
                     std::vector<std::string> sampler_names;
-                    for (const auto & sampler_name : *samplers_sequence) {
+                    for (const auto & sampler_name : *samplers) {
                         if (sampler_name.is_string()) {
                             sampler_names.emplace_back(sampler_name);
                         }
                     }
-                    slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
+                    slot.sparams.samplers = llama_sampling_types_from_names(sampler_names, false);
                 } else {
-                    slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
+                    slot.sparams.samplers = default_sparams.samplers;
                 }
             }

             {
-                if (slot.ctx_sampling != nullptr) {
-                    llama_sampling_free(slot.ctx_sampling);
+                if (slot.smpl != nullptr) {
+                    llama_sampling_free(slot.smpl);
                 }
-                slot.ctx_sampling = llama_sampling_init(slot.sparams);
-                if (slot.ctx_sampling == nullptr) {
+
+                slot.smpl = llama_sampling_init(model, slot.sparams);
+                if (slot.smpl == nullptr) {
                     // for now, the only error that may happen here is invalid grammar
                     send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                     return false;
@@ -1159,11 +1117,6 @@ struct server_context {
         slot.generated_text += token_str;
         slot.has_next_token = true;

-        if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
-            // we can change penalty_prompt_tokens because it is always created from scratch each request
-            slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
-        }
-
         // check if there is incomplete UTF-8 character at the end
         bool incomplete = false;
         for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
@@ -1281,13 +1234,10 @@ struct server_context {
     }

     json get_formated_generation(const server_slot & slot) const {
-        const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
-        const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-
-        std::vector<std::string> samplers_sequence;
-        samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
-        for (const auto & sampler_type : slot.sparams.samplers_sequence) {
-            samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
+        std::vector<std::string> samplers;
+        samplers.reserve(slot.sparams.samplers.size());
+        for (const auto & sampler : slot.sparams.samplers) {
+            samplers.emplace_back(llama_sampling_type_to_str(sampler));
         }

         return json {

@@ -1302,13 +1252,11 @@ struct server_context {
             {"top_p", slot.sparams.top_p},
             {"min_p", slot.sparams.min_p},
             {"tfs_z", slot.sparams.tfs_z},
-            {"typical_p", slot.sparams.typical_p},
+            {"typical_p", slot.sparams.typ_p},
             {"repeat_last_n", slot.sparams.penalty_last_n},
             {"repeat_penalty", slot.sparams.penalty_repeat},
             {"presence_penalty", slot.sparams.penalty_present},
             {"frequency_penalty", slot.sparams.penalty_freq},
-            {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
-            {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
             {"mirostat", slot.sparams.mirostat},
             {"mirostat_tau", slot.sparams.mirostat_tau},
             {"mirostat_eta", slot.sparams.mirostat_eta},

@@ -1317,13 +1265,13 @@ struct server_context {
             {"max_tokens", slot.params.n_predict}, // User configured n_predict
             {"n_keep", slot.params.n_keep},
             {"n_discard", slot.params.n_discard},
-            {"ignore_eos", ignore_eos},
+            {"ignore_eos", slot.sparams.ignore_eos},
             {"stream", slot.params.stream},
-            {"logit_bias", slot.sparams.logit_bias},
+          //{"logit_bias", slot.sparams.logit_bias},
             {"n_probs", slot.sparams.n_probs},
             {"min_keep", slot.sparams.min_keep},
             {"grammar", slot.sparams.grammar},
-            {"samplers", samplers_sequence}
+            {"samplers", samplers},
         };
     }

@@ -2139,7 +2087,7 @@ struct server_context {
                     GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                 }

-                llama_sampling_reset(slot.ctx_sampling);
+                llama_sampling_reset(slot.smpl);

                 if (!slot.params.cache_prompt) {
                     slot.n_past_se = 0;

@@ -2152,7 +2100,7 @@ struct server_context {

                     // push the prompt into the sampling context (do not apply grammar)
                     for (int i = 0; i < slot.n_past; ++i) {
-                        llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+                        llama_sampling_accept(slot.smpl, slot.cache_tokens[i], false);
                     }
                 }
             }

@@ -2205,7 +2153,7 @@ struct server_context {
                     slot.n_past_se = 0;
                     slot.ga_i = 0;
                     // TODO: is the system prompt ever in the sampling context?
-                    llama_sampling_reset(slot.ctx_sampling);
+                    llama_sampling_reset(slot.smpl);
                 }

                 // remove the non-common part from the cache
@@ -2382,9 +2330,9 @@ struct server_context {
            }

            completion_token_output result;
-           const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+           const llama_token id = llama_sampling_sample(slot.smpl, ctx, slot.i_batch - i);

-           llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
+           llama_sampling_accept(slot.smpl, id, true);

            slot.n_decoded += 1;
            if (slot.n_decoded == 1) {

@@ -2393,35 +2341,18 @@ struct server_context {
                metrics.on_prompt_eval(slot);
            }

-           llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
            result.tok = id;

-           const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-           if (n_probs > 0) {
-               const size_t n_valid = slot.ctx_sampling->n_valid;
+           const auto * cur_p = llama_sampling_get_candidates(slot.smpl);

-               // Make sure at least n_probs top tokens are at the front of the vector:
-               if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                   llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-               }
-
-               if (slot.sparams.temp == 0.0f) {
-                   // With greedy sampling the probabilities have possibly not been calculated.
-                   for (size_t i = 0; i < n_probs; ++i) {
-                       result.probs.push_back({
-                           cur_p.data[i].id,
-                           i == 0 ? 1.0f : 0.0f
-                       });
-                   }
-               } else {
-                   for (size_t i = 0; i < n_probs; ++i) {
-                       result.probs.push_back({
-                           cur_p.data[i].id,
-                           i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                       });
-                   }
-               }
-           }
+           // TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643
+           //       fix if necessary
+           for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+               result.probs.push_back({
+                   cur_p->data[i].id,
+                   i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+               });
+           }

            if (!process_token(result, slot)) {
                slot.release();
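Note (not part of the patch): with the new accessor, per-token probabilities are read straight from the sampler's internal candidate array rather than from `ctx_sampling->cur`. A hedged sketch of that read path, assuming the candidates were already filtered and sorted by the preceding `llama_sampling_sample()` call (the clamping below is an illustrative safeguard, not the server's exact logic):

```cpp
// Sketch: inspect the top-n candidates after a sampling call.
const llama_token_data_array * cur_p = llama_sampling_get_candidates(slot.smpl);

const size_t n_probs = std::min((size_t) slot.sparams.n_probs, cur_p->size);
for (size_t i = 0; i < n_probs; ++i) {
    // id and p come from the sampler's internal, already-shaped distribution
    printf("token %d has probability %.3f\n", cur_p->data[i].id, cur_p->data[i].p);
}
```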
@@ -55,6 +55,8 @@ int main(int argc, char ** argv) {
         return 1;
     }

+    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
+
     // tokenize the prompt

     std::vector<llama_token> tokens_list;

@@ -110,20 +112,12 @@ int main(int argc, char ** argv) {
     while (n_cur <= n_predict) {
         // sample the next token
         {
-            auto n_vocab = llama_n_vocab(model);
-            auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(n_vocab);
-
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-            }
-
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+            const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+
+            llama_sampling_set_logits(smpl, logits);

             // sample the most likely token
-            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+            const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);

             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {

@@ -160,12 +154,13 @@ int main(int argc, char ** argv) {
     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

-    llama_print_timings(ctx);
+    llama_print_timings(ctx, nullptr);

     fprintf(stderr, "\n");

     llama_batch_free(batch);

+    llama_sampling_free(smpl);
     llama_free(ctx);
     llama_free_model(model);

@@ -21,7 +21,7 @@ struct seq_draft {
     std::vector<llama_token> tokens;
     std::vector<std::vector<llama_token_data>> dists;

-    struct llama_sampling_context * ctx_sampling;
+    struct llama_sampling * smpl;
 };

 int main(int argc, char ** argv) {

@@ -37,16 +37,16 @@ int main(int argc, char ** argv) {
         return 1;
     }

+    // for probabilities to be computed even with temp = 0
+    params.sparams.n_probs = 16;
+
     // max number of parallel drafting sequences (i.e. tree branches)
     const int n_seq_dft = params.n_parallel;

     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
     const float p_split = params.p_split;

-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-    std::default_random_engine rng(params.seed);
+    std::default_random_engine rng(params.sparams.seed);
     std::uniform_real_distribution<> u_dist;

 #ifndef LOG_DISABLE_LOGS

@@ -179,19 +179,15 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;

-    // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    // target model sampling context (reuse the llama_context's sampling instance)
+    struct llama_sampling * smpl = llama_sampling_init(model_tgt, params.sparams);

     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);

-    params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    if (params.sparams.temp == 0) {
-        params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
-    }
-
     for (int s = 0; s < n_seq_dft; ++s) {
-        drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
+        // allocate llama_sampling for each draft sequence
+        drafts[s].smpl = llama_sampling_init(model_dft, params.sparams);
     }

     llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);

@@ -234,9 +230,15 @@ int main(int argc, char ** argv) {
             if (params.sparams.temp > 0) {
                 // stochastic verification

-                llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
-                llama_sample_softmax(ctx_tgt, &dist_tgt);
-                float p_tgt = 0, p_dft = 0;
+                llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]));
+
+                auto & dist_tgt = *llama_sampling_get_candidates(smpl);
+
+                llama_sampling_grammar(smpl, &dist_tgt);
+                llama_sampling_softmax(smpl, &dist_tgt);
+
+                float p_tgt = 0.0f;
+                float p_dft = 0.0f;

                 // GGML_ASSERT(dist_tgt.size() == dist_dft.size());

@@ -278,7 +280,7 @@ int main(int argc, char ** argv) {
                     accept = true;
                     token_id = drafts[s].tokens[i_dft];
                     token_str = llama_token_to_piece(ctx_tgt, token_id);
-                    llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+                    llama_sampling_accept(smpl, token_id, true);

                     LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
                     break;

@@ -332,8 +334,8 @@ int main(int argc, char ** argv) {
                 // all drafted tokens were rejected
                 // sample from the target model
                 LOG("all drafted tokens were rejected, sampling from residual distribution\n");
-                token_id = llama_sample_token(ctx_tgt, &dist_tgt);
-                llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+                token_id = llama_sampling_sample_dist(smpl, &dist_tgt);
+                llama_sampling_accept(smpl, token_id, true);
                 token_str = llama_token_to_piece(ctx_tgt, token_id);
             }

@@ -342,11 +344,11 @@ int main(int argc, char ** argv) {

             // sample from the target model
             LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
-            token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+            token_id = llama_sampling_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

-            llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+            llama_sampling_accept(smpl, token_id, true);

-            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
+            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str());

             token_str = llama_token_to_piece(ctx_tgt, token_id);

@@ -434,7 +436,7 @@ int main(int argc, char ** argv) {
             break;
         }

-        llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
+        llama_sampling_cp(smpl, drafts[0].smpl);

         int n_seq_cur = 1;
         int n_past_cur = n_past_dft;

@@ -463,20 +465,20 @@ int main(int argc, char ** argv) {
                 continue;
             }

-            llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+            llama_sampling_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);

-            const auto & cur_p = drafts[s].ctx_sampling->cur;
+            const auto * cur_p = llama_sampling_get_candidates(drafts[s].smpl);

-            for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
+            for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
                 LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                    k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
+                    k, s, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
             }

             std::vector<int> sa(1, s);

             // attempt to split the branch if the probability is high enough
             for (int f = 1; f < 8; ++f) {
-                if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
+                if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
                     LOG("splitting seq %3d into %3d\n", s, n_seq_cur);

                     llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);

@@ -503,7 +505,7 @@ int main(int argc, char ** argv) {
                     drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                     drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;

-                    llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
+                    llama_sampling_cp(drafts[s].smpl, drafts[n_seq_cur].smpl);

                     sa.push_back(n_seq_cur);

@@ -515,15 +517,15 @@ int main(int argc, char ** argv) {

             // add drafted token for each sequence
             for (int is = 0; is < (int) sa.size(); ++is) {
-                const llama_token id = cur_p[is].id;
+                const llama_token id = cur_p->data[is].id;

                 const int s = sa[is];

-                llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
+                llama_sampling_accept(drafts[s].smpl, id, true);

                 drafts[s].tokens.push_back(id);
                 // save cur_p.data into drafts[s].dists
-                drafts[s].dists.push_back(cur_p);
+                drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});

                 // add unique drafted tokens to the target batch
                 drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);

@@ -594,14 +596,15 @@ int main(int argc, char ** argv) {
     LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

     LOG_TEE("\ndraft:\n");
-    llama_print_timings(ctx_dft);
+    // TODO: print sampling/grammar timings for all drafts
+    llama_print_timings(ctx_dft, nullptr);

     LOG_TEE("\ntarget:\n");
-    llama_print_timings(ctx_tgt);
+    llama_print_timings(ctx_tgt, smpl);

-    llama_sampling_free(ctx_sampling);
+    llama_sampling_free(smpl);
     for (int s = 0; s < n_seq_dft; ++s) {
-        llama_sampling_free(drafts[s].ctx_sampling);
+        llama_sampling_free(drafts[s].smpl);
     }

     llama_batch_free(batch_dft);
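Note (not part of the patch): a rough summary sketch of the per-sequence sampler lifecycle in the speculative example above, with the function forms used in that example (the two-argument `llama_sampling_cp(src, dst)` and three-argument `llama_sampling_accept(...)` calls are the helper forms that appear in the example code; the core `llama.h` below declares a one-argument copy that returns a new sampler). Error handling and batching are omitted:

```cpp
// Sketch: one sampler for the target model, one per draft branch.
llama_sampling * smpl = llama_sampling_init(model_tgt, params.sparams);
std::vector<seq_draft> drafts(n_seq_dft);
for (auto & d : drafts) {
    d.smpl = llama_sampling_init(model_dft, params.sparams);  // each branch keeps its own state
}

// when a branch is split, sampler state (prev tokens, grammar, rng) is duplicated
llama_sampling_cp(smpl, drafts[0].smpl);

// every accepted token is pushed into both the target and the owning draft sampler
llama_sampling_accept(smpl, token_id, true);
llama_sampling_accept(drafts[0].smpl, token_id, true);

// teardown mirrors the setup
llama_sampling_free(smpl);
for (auto & d : drafts) {
    llama_sampling_free(d.smpl);
}
```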
include/llama.h | 395

@@ -33,16 +33,21 @@

 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF

+// TODO: use everywhere in the implementation
+#define LLAMA_TOKEN_NULL -1
+
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 8
+#define LLAMA_SESSION_VERSION 9

 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 2

+#define LLAMA_MAX_SAMPLERS 16
+
 #ifdef __cplusplus
 extern "C" {
 #endif

@@ -53,8 +58,10 @@ extern "C" {
     // TODO: show sample usage
     //

+    // struct llama_vocab; // TODO: add in the future
     struct llama_model;
     struct llama_context;
+    struct llama_sampling;

     typedef int32_t llama_pos;
     typedef int32_t llama_token;

@@ -199,6 +206,16 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
     };

+    enum llama_sampler_type {
+        LLAMA_SAMPLER_TYPE_NONE        = 0,
+        LLAMA_SAMPLER_TYPE_TOP_K       = 1,
+        LLAMA_SAMPLER_TYPE_TOP_P       = 2,
+        LLAMA_SAMPLER_TYPE_MIN_P       = 3,
+        LLAMA_SAMPLER_TYPE_TFS_Z       = 4,
+        LLAMA_SAMPLER_TYPE_TYPICAL_P   = 5,
+        LLAMA_SAMPLER_TYPE_TEMPERATURE = 6,
+    };
+
     typedef struct llama_token_data {
         llama_token id; // token id
         float logit;    // log-odds of the token

@@ -206,6 +223,7 @@ extern "C" {
     } llama_token_data;

     typedef struct llama_token_data_array {
+        // TODO: consider SoA
        llama_token_data * data;
        size_t size;
        bool sorted;
@@ -300,7 +318,6 @@ extern "C" {
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
    //       https://github.com/ggerganov/llama.cpp/pull/7544
    struct llama_context_params {
-        uint32_t seed;     // RNG seed, -1 for random
         uint32_t n_ctx;    // text context, 0 = from model
         uint32_t n_batch;  // logical maximum batch size that can be submitted to llama_decode
         uint32_t n_ubatch; // physical maximum batch size

@@ -328,7 +345,8 @@ extern "C" {
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]

-        // Keep the booleans together to avoid misalignment during copy-by-value.
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        // TODO: move at the end of the struct
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU

@@ -356,53 +374,56 @@ extern "C" {
         void * kv_overrides; // pointer to vector containing overrides
     } llama_model_quantize_params;

-    // grammar types
-    struct llama_grammar;
-
-    // grammar element type
-    enum llama_gretype {
-        // end of rule definition
-        LLAMA_GRETYPE_END = 0,
-
-        // start of alternate definition for rule
-        LLAMA_GRETYPE_ALT = 1,
-
-        // non-terminal element: reference to rule
-        LLAMA_GRETYPE_RULE_REF = 2,
-
-        // terminal element: character (code point)
-        LLAMA_GRETYPE_CHAR = 3,
-
-        // inverse char(s) ([^a], [^a-b] [^abc])
-        LLAMA_GRETYPE_CHAR_NOT = 4,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
-        // be an inclusive range ([a-z])
-        LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or
-        // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
-        LLAMA_GRETYPE_CHAR_ALT = 6,
-
-        // any character (.)
-        LLAMA_GRETYPE_CHAR_ANY = 7,
-    };
-
-    typedef struct llama_grammar_element {
-        enum llama_gretype type;
-        uint32_t value; // Unicode code point or rule ID
-    } llama_grammar_element;
+    typedef struct llama_logit_bias {
+        llama_token token;
+        float bias;
+    } llama_logit_bias;
+
+    // parameters for sampling the logits
+    typedef struct llama_sampling_params {
+        uint32_t seed;              // the seed used to initialize llama_sampling_context
+        int32_t  n_prev;            // number of previous tokens to remember
+        int32_t  n_probs;           // if greater than 0, output the probabilities of top n_probs tokens.
+        int32_t  min_keep;          // 0 = disabled, otherwise samplers should return at least min_keep tokens
+        int32_t  top_k;             // <= 0 to use vocab size
+        float    top_p;             // 1.0 = disabled
+        float    min_p;             // 0.0 = disabled
+        float    tfs_z;             // 1.0 = disabled
+        float    typ_p;             // typical_p, 1.0 = disabled
+        float    temp;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
+        float    dynatemp_range;    // 0.0 = disabled
+        float    dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler
+        int32_t  penalty_last_n;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+        float    penalty_repeat;    // 1.0 = disabled
+        float    penalty_freq;      // 0.0 = disabled
+        float    penalty_present;   // 0.0 = disabled
+        int32_t  mirostat;          // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        float    mirostat_tau;      // target entropy
+        float    mirostat_eta;      // learning rate
+
+        // samplers
+        int32_t n_samplers;
+        enum llama_sampler_type samplers[LLAMA_MAX_SAMPLERS];
+
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        bool penalize_nl; // consider newlines as a repeatable token
+        bool ignore_eos;  // ignore the end-of-sequence token
+    } llama_sampling_params;

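Note (not part of the patch): to make the new struct concrete, a hedged sketch of filling `llama_sampling_params` and creating a sampler for a loaded `model`. The numeric values are illustrative, `llama_sampling_default_params()` and `llama_sampling_init()`/`llama_sampling_free()` are declared further down in this header diff, and the assumption that `samplers[]`/`n_samplers` define the order in which the constraint samplers are applied follows from their names:

```cpp
// Sketch: configure the new sampling parameters and build a sampler.
llama_sampling_params sparams = llama_sampling_default_params();

sparams.seed  = 1234;  // RNG seed for the sampler
sparams.top_k = 40;    // keep the 40 most likely tokens
sparams.top_p = 0.95f; // nucleus sampling threshold
sparams.temp  = 0.8f;  // softmax temperature

// presumed application order of the samplers
sparams.n_samplers  = 3;
sparams.samplers[0] = LLAMA_SAMPLER_TYPE_TOP_K;
sparams.samplers[1] = LLAMA_SAMPLER_TYPE_TOP_P;
sparams.samplers[2] = LLAMA_SAMPLER_TYPE_TEMPERATURE;

llama_sampling * smpl = llama_sampling_init(model, sparams);
// ... sample with smpl ...
llama_sampling_free(smpl);
```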
     // performance timing information
     struct llama_timings {
         double t_start_ms;
         double t_end_ms;
         double t_load_ms;
-        double t_sample_ms;
+        double t_sampling_ms;
+        double t_grammar_ms;
+        double t_accept_ms;
         double t_p_eval_ms;
         double t_eval_ms;

-        int32_t n_sample;
+        int32_t n_sampling;
+        int32_t n_grammar;
+        int32_t n_accept;
         int32_t n_p_eval;
         int32_t n_eval;
     };

@@ -419,6 +440,7 @@ extern "C" {
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
+    LLAMA_API struct llama_sampling_params llama_sampling_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);

     // Initialize the llama + ggml backend

@@ -445,6 +467,7 @@ extern "C" {

     LLAMA_API void llama_free_model(struct llama_model * model);

+    // TODO: rename to llama_init_from_model
     LLAMA_API struct llama_context * llama_new_context_with_model(
                      struct llama_model * model,
             struct llama_context_params params);

@@ -460,23 +483,22 @@ extern "C" {
     LLAMA_API bool llama_supports_mlock      (void);
     LLAMA_API bool llama_supports_gpu_offload(void);

-    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
-
     LLAMA_API uint32_t llama_n_ctx     (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);

-    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
-
-    LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type  llama_rope_type  (const struct llama_model * model);
-
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_n_layer    (const struct llama_model * model);

+    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+
+    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
+
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);

@@ -704,7 +726,7 @@ extern "C" {
     //

     // Returns the *actual* size in bytes of the state
-    // (rng, logits, embedding and kv_cache)
+    // (logits, embedding and kv_cache)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@ -1006,160 +1028,132 @@ extern "C" {
|
|||||||
char * buf,
|
char * buf,
|
||||||
int32_t length);
|
int32_t length);
|
||||||
|
|
||||||
//
|
|
||||||
// Grammar
|
|
||||||
//
|
|
||||||
|
|
||||||
/// Initialize a llama_grammar.
|
|
||||||
///
|
|
||||||
/// @param rules The rule elements of the grammar to initialize.
|
|
||||||
/// @param n_rules The number of rules.
|
|
||||||
/// @param start_rule_index The index of the root rule (the starting point of the grammar).
|
|
||||||
/// @return The initialized llama_grammar or nullptr if initialization failed.
|
|
||||||
LLAMA_API struct llama_grammar * llama_grammar_init(
|
|
||||||
const llama_grammar_element ** rules,
|
|
||||||
size_t n_rules,
|
|
||||||
size_t start_rule_index);
|
|
||||||
|
|
||||||
LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
|
|
||||||
|
|
||||||
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
|
|
||||||
|
|
||||||
/// @details Apply constraints from grammar
|
|
||||||
LLAMA_API void llama_grammar_sample(
|
|
||||||
const struct llama_grammar * grammar,
|
|
||||||
const struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates);
|
|
||||||
LLAMA_API DEPRECATED(void llama_sample_grammar(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates,
|
|
||||||
const struct llama_grammar * grammar),
|
|
||||||
"use llama_grammar_sample instead");
|
|
||||||
|
|
||||||
/// @details Accepts the sampled token into the grammar
|
|
||||||
LLAMA_API void llama_grammar_accept_token(
|
|
||||||
struct llama_grammar * grammar,
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_token token);
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Sampling functions
|
// Sampling functions
|
||||||
//
|
//
|
||||||
|
|
||||||
// Sets the current rng seed.
|
// TODO: llama_model should become llama_vocab
|
||||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params);
|
||||||
|
|
||||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
|
||||||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
|
||||||
LLAMA_API void llama_sample_repetition_penalties(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates,
|
|
||||||
const llama_token * last_tokens,
|
|
||||||
size_t penalty_last_n,
|
|
||||||
float penalty_repeat,
|
|
||||||
float penalty_freq,
|
|
||||||
float penalty_present);
|
|
||||||
|
|
||||||
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
|
// Copies the internal state of the sampler (rng, prev, params, grammar, etc.)
|
||||||
/// @param logits Logits extracted from the original generation context.
|
LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);
|
||||||
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
|
||||||
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
// - clear prev token
|
||||||
LLAMA_API void llama_sample_apply_guidance(
|
// - reset grammar state
|
||||||
struct llama_context * ctx,
|
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);
|
||||||
float * logits,
|
|
||||||
float * logits_guidance,
|
// Sampling parameter mutation
|
||||||
float scale);
|
// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable
|
||||||
|
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
|
||||||
|
LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
|
||||||
|
|
||||||
|
// Set the logits from which to sample.
|
||||||
|
// This call initializes the internal token candidates array.
|
||||||
|
// The internal candidates are implicitly used by the sampling API below when no candidates are provided.
|
||||||
|
LLAMA_API void llama_sampling_set_logits(
|
||||||
|
struct llama_sampling * smpl,
|
||||||
|
const float * logits);
|
||||||
|
|
||||||
|
/// @details Returns the current candidate tokens.
|
||||||
|
LLAMA_API llama_token_data_array * llama_sampling_get_candidates(
|
||||||
|
struct llama_sampling * smpl);
|
||||||
|
|
||||||
|
// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object.
|
||||||
|
// Each function can accept an array of token candidates. If the candidates are not provided, the internal
|
||||||
|
// candidates are used. The internal candidates are initialized by llama_sampling_set_logits().
|
||||||
|
|
||||||
/// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
|
/// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
|
||||||
LLAMA_API void llama_sample_softmax(
|
LLAMA_API void llama_sampling_softmax(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates);
|
llama_token_data_array * candidates);
|
||||||
|
|
||||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||||
LLAMA_API void llama_sample_top_k(
|
LLAMA_API void llama_sampling_top_k(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates);
|
||||||
int32_t k,
|
|
||||||
size_t min_keep);
|
|
||||||
|
|
||||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||||
LLAMA_API void llama_sample_top_p(
|
LLAMA_API void llama_sampling_top_p(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates);
|
||||||
float p,
|
|
||||||
size_t min_keep);
|
|
||||||
|
|
||||||
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
|
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
|
||||||
LLAMA_API void llama_sample_min_p(
|
LLAMA_API void llama_sampling_min_p(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates);
|
||||||
float p,
|
|
||||||
size_t min_keep);
|
|
||||||
|
|
||||||
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||||
LLAMA_API void llama_sample_tail_free(
|
LLAMA_API void llama_sampling_tail_free(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates);
|
||||||
float z,
|
|
||||||
size_t min_keep);
|
|
||||||
|
|
||||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||||
LLAMA_API void llama_sample_typical(
|
LLAMA_API void llama_sampling_typical(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates);
|
||||||
float p,
|
|
||||||
size_t min_keep);
|
|
||||||
|
|
||||||
/// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
|
/// @details Apply temperature and entropy
|
||||||
LLAMA_API void llama_sample_entropy(
|
LLAMA_API void llama_sampling_temp(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates_p,
|
llama_token_data_array * candidates);
|
||||||
float min_temp,
|
|
||||||
float max_temp,
|
|
||||||
float exponent_val);
|
|
||||||
|
|
||||||
LLAMA_API void llama_sample_temp(
|
/// @details Apply constraints from grammar
|
||||||
struct llama_context * ctx,
|
LLAMA_API void llama_sampling_grammar(
|
||||||
llama_token_data_array * candidates,
|
struct llama_sampling * smpl,
|
||||||
float temp);
|
llama_token_data_array * candidates);
|
||||||
|
|
||||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
LLAMA_API void llama_sampling_penalties(
|
||||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
struct llama_sampling * smpl,
|
||||||
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
llama_token_data_array * candidates);
|
||||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
|
||||||
LLAMA_API llama_token llama_sample_token_mirostat(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates,
|
|
||||||
float tau,
|
|
||||||
float eta,
|
|
||||||
int32_t m,
|
|
||||||
float * mu);
|
|
||||||
|
|
||||||
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
/// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
LLAMA_API llama_token llama_sampling_sample_mirostat(
|
||||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
struct llama_sampling * smpl,
|
||||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
llama_token_data_array * candidates);
|
||||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
|
||||||
LLAMA_API llama_token llama_sample_token_mirostat_v2(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates,
|
|
||||||
float tau,
|
|
||||||
float eta,
|
|
||||||
float * mu);
|
|
||||||
|
|
||||||
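A hedged usage sketch for the Mirostat entry points above; `ctx` and `candidates` are assumed, and the starting values simply follow the parameter descriptions (tau = 5.0 and eta = 0.1 are illustrative, not mandated by this header):

    float tau = 5.0f;        // target cross-entropy (surprise)
    float eta = 0.1f;        // learning rate for mu
    float mu  = 2.0f * tau;  // per the @param mu note: start at twice the target cross-entropy
    const llama_token id = llama_sample_token_mirostat_v2(ctx, &candidates, tau, eta, &mu);
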
/// @details Selects the token with the highest probability.
|
/// @details Selects the token with the highest probability.
|
||||||
/// Does not compute the token probabilities. Use llama_sample_softmax() instead.
|
/// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
|
||||||
LLAMA_API llama_token llama_sample_token_greedy(
|
LLAMA_API llama_token llama_sampling_sample_greedy(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates);
|
llama_token_data_array * candidates);
|
||||||
|
|
||||||
/// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
|
/// @details Randomly selects a token from the candidates based on their probability distribution.
|
||||||
LLAMA_API llama_token llama_sample_token(
|
LLAMA_API llama_token llama_sampling_sample_dist(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates);
|
llama_token_data_array * candidates);
|
||||||
|
|
||||||
|
/// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers").
|
||||||
|
LLAMA_API llama_token llama_sampling_sample(
|
||||||
|
struct llama_sampling * smpl,
|
||||||
|
llama_token_data_array * candidates);
|
||||||
|
|
||||||
|
/// @details Accepts the sampled token into the sampling context.
|
||||||
|
/// - adds it to "prev" tokens
|
||||||
|
/// - updates the grammar state (if apply_grammar is true)
|
||||||
|
LLAMA_API void llama_sampling_accept(
|
||||||
|
struct llama_sampling * smpl,
|
||||||
|
llama_token token,
|
||||||
|
bool apply_grammar);
|
||||||
|
|
||||||
|
/// @details Get the number of accepted tokens so far (max of n_prev)
|
||||||
|
LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);
|
||||||
|
|
||||||
|
/// @details Get the ith accepted token
|
||||||
|
/// @param ith [0, n_prev), ith == 0 is the last accepted token.
|
||||||
|
/// returns LLAMA_TOKEN_NULL if ith is out of bounds
|
||||||
|
LLAMA_API llama_token llama_sampling_prev(
|
||||||
|
const struct llama_sampling * smpl,
|
||||||
|
int32_t ith);
|
||||||
|
|
||||||
|
/// @details Get the last accepted token
|
||||||
|
/// Same as llama_sampling_prev(smpl, 0)
|
||||||
|
/// returns LLAMA_TOKEN_NULL if there are no accepted tokens
|
||||||
|
LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);
|
||||||
|
|
||||||
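To make the intended flow of the new object-based API concrete, here is a hedged end-to-end sketch assembled from the declarations above; `model`, a decoded `ctx` and a `sparams` value of type llama_sampling_params are assumed to exist, and the exact sequence is illustrative rather than part of the commit:

    struct llama_sampling * smpl = llama_sampling_init(model, sparams);

    // after llama_decode(): hand the last logits to the sampler,
    // which (re)initializes the internal candidates array
    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, -1));

    llama_token_data_array * cur = llama_sampling_get_candidates(smpl);
    llama_sampling_top_k(smpl, cur);
    llama_sampling_top_p(smpl, cur);
    llama_sampling_temp (smpl, cur);

    const llama_token id = llama_sampling_sample_dist(smpl, cur);
    llama_sampling_accept(smpl, id, /*apply_grammar=*/true);
    // ...
    llama_sampling_free(smpl);

Passing nullptr instead of `cur` falls back to the internal candidates, and llama_sampling_sample(smpl, nullptr) collapses the manual chain into whatever samplers were configured in the params.
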
//
|
//
|
||||||
// Model split
|
// Model split
|
||||||
//
|
//
|
||||||
@ -1177,8 +1171,8 @@ extern "C" {
|
|||||||
// Performance information
|
// Performance information
|
||||||
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
||||||
|
|
||||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl);
|
||||||
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
|
LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl);
|
||||||
|
|
||||||
// Print system information
|
// Print system information
|
||||||
LLAMA_API const char * llama_print_system_info(void);
|
LLAMA_API const char * llama_print_system_info(void);
|
||||||
@ -1193,59 +1187,4 @@ extern "C" {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
|
|
||||||
#ifdef LLAMA_API_INTERNAL
|
|
||||||
|
|
||||||
#include <random>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
struct ggml_tensor;
|
|
||||||
|
|
||||||
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
|
||||||
struct llama_context * ctx
|
|
||||||
);
|
|
||||||
|
|
||||||
struct llama_partial_utf8 {
|
|
||||||
uint32_t value; // bit value so far (unshifted)
|
|
||||||
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
|
||||||
};
|
|
||||||
|
|
||||||
struct llama_grammar_candidate {
|
|
||||||
size_t index;
|
|
||||||
const uint32_t * code_points;
|
|
||||||
llama_partial_utf8 partial_utf8;
|
|
||||||
};
|
|
||||||
|
|
||||||
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
|
||||||
using llama_grammar_stack = std::vector<const llama_grammar_element *>;
|
|
||||||
|
|
||||||
using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
|
||||||
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
|
||||||
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
|
||||||
|
|
||||||
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
|
||||||
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
|
||||||
|
|
||||||
void llama_grammar_accept(
|
|
||||||
const llama_grammar_rules & rules,
|
|
||||||
const llama_grammar_stacks & stacks,
|
|
||||||
const uint32_t chr,
|
|
||||||
llama_grammar_stacks & new_stacks);
|
|
||||||
|
|
||||||
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
|
||||||
const llama_grammar_rules & rules,
|
|
||||||
const llama_grammar_stack & stack,
|
|
||||||
const llama_grammar_candidates & candidates);
|
|
||||||
|
|
||||||
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
|
||||||
const std::string & src,
|
|
||||||
llama_partial_utf8 partial_start);
|
|
||||||
|
|
||||||
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
|
|
||||||
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
|
|
||||||
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
|
|
||||||
|
|
||||||
#endif // LLAMA_API_INTERNAL
|
|
||||||
|
|
||||||
#endif // LLAMA_H
|
#endif // LLAMA_H
|
||||||
|
@ -4,10 +4,29 @@
|
|||||||
#include "llama-sampling.h"
|
#include "llama-sampling.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
//
|
||||||
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
|
// helpers
|
||||||
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
//
|
||||||
|
|
||||||
|
// NOTE: assumes valid utf8 (but checks for overrun)
|
||||||
|
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
|
||||||
|
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||||
|
uint8_t first_byte = static_cast<uint8_t>(*src);
|
||||||
|
uint8_t highbits = first_byte >> 4;
|
||||||
|
int len = lookup[highbits];
|
||||||
|
uint8_t mask = (1 << (8 - len)) - 1;
|
||||||
|
uint32_t value = first_byte & mask;
|
||||||
|
const char * end = src + len; // may overrun!
|
||||||
|
const char * pos = src + 1;
|
||||||
|
for ( ; pos < end && *pos; pos++) {
|
||||||
|
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||||
|
}
|
||||||
|
return std::make_pair(value, pos);
|
||||||
|
}
|
||||||
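// Worked example for the helper above: for the euro sign "\xE2\x82\xAC" (U+20AC)
// the first byte 0xE2 has high nibble 0xE, so lookup[14] == 3 and mask == 0x1F;
// value starts at 0xE2 & 0x1F == 0x02, then the two continuation bytes fold in:
// (0x02 << 6) + 0x02 == 0x82 and (0x82 << 6) + 0x2C == 0x20AC.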
|
|
||||||
|
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
||||||
const std::string & src,
|
const std::string & src,
|
||||||
llama_partial_utf8 partial_start) {
|
llama_partial_utf8 partial_start) {
|
||||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
||||||
@ -67,12 +86,510 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
|||||||
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
|
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
|
static bool is_digit_char(char c) {
|
||||||
return grammar->rules;
|
return '0' <= c && c <= '9';
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
|
static bool is_word_char(char c) {
|
||||||
return grammar->stacks;
|
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||||
|
const char * pos = src;
|
||||||
|
const char * end = src + size;
|
||||||
|
uint32_t value = 0;
|
||||||
|
for ( ; pos < end && *pos; pos++) {
|
||||||
|
value <<= 4;
|
||||||
|
char c = *pos;
|
||||||
|
if ('a' <= c && c <= 'f') {
|
||||||
|
value += c - 'a' + 10;
|
||||||
|
} else if ('A' <= c && c <= 'F') {
|
||||||
|
value += c - 'A' + 10;
|
||||||
|
} else if ('0' <= c && c <= '9') {
|
||||||
|
value += c - '0';
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (pos != end) {
|
||||||
|
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
|
||||||
|
}
|
||||||
|
return std::make_pair(value, pos);
|
||||||
|
}
|
||||||
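// For reference: the '\u' escape handled in parse_char() below calls
// parse_hex(src + 2, 4), so "\u20AC" accumulates 0x2, 0x0, 0xA, 0xC into 0x20AC
// and returns the position just past the fourth digit; fewer than `size` valid
// hex digits hit the "expecting N hex chars" error above.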
|
|
||||||
|
static const char * parse_space(const char * src, bool newline_ok) {
|
||||||
|
const char * pos = src;
|
||||||
|
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
|
||||||
|
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
|
||||||
|
if (*pos == '#') {
|
||||||
|
while (*pos && *pos != '\r' && *pos != '\n') {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * parse_name(const char * src) {
|
||||||
|
const char * pos = src;
|
||||||
|
while (is_word_char(*pos)) {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
if (pos == src) {
|
||||||
|
throw std::runtime_error(std::string("expecting name at ") + src);
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * parse_int(const char * src) {
|
||||||
|
const char * pos = src;
|
||||||
|
while (is_digit_char(*pos)) {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
if (pos == src) {
|
||||||
|
throw std::runtime_error(std::string("expecting integer at ") + src);
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||||
|
if (*src == '\\') {
|
||||||
|
switch (src[1]) {
|
||||||
|
case 'x': return parse_hex(src + 2, 2);
|
||||||
|
case 'u': return parse_hex(src + 2, 4);
|
||||||
|
case 'U': return parse_hex(src + 2, 8);
|
||||||
|
case 't': return std::make_pair('\t', src + 2);
|
||||||
|
case 'r': return std::make_pair('\r', src + 2);
|
||||||
|
case 'n': return std::make_pair('\n', src + 2);
|
||||||
|
case '\\':
|
||||||
|
case '"':
|
||||||
|
case '[':
|
||||||
|
case ']':
|
||||||
|
return std::make_pair(src[1], src + 2);
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(std::string("unknown escape at ") + src);
|
||||||
|
}
|
||||||
|
} else if (*src) {
|
||||||
|
return decode_utf8(src);
|
||||||
|
}
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print_grammar_char(FILE * file, uint32_t c) {
|
||||||
|
if (0x20 <= c && c <= 0x7f) {
|
||||||
|
fprintf(file, "%c", static_cast<char>(c));
|
||||||
|
} else {
|
||||||
|
// cop out of encoding UTF-8
|
||||||
|
fprintf(file, "<U+%04X>", c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool is_char_element(llama_grammar_element elem) {
|
||||||
|
switch (elem.type) {
|
||||||
|
case LLAMA_GRETYPE_CHAR: return true;
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT: return true;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT: return true;
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY: return true;
|
||||||
|
default: return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
|
||||||
|
for (auto elem : rule) {
|
||||||
|
switch (elem.type) {
|
||||||
|
case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
|
||||||
|
case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
||||||
|
case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
|
||||||
|
}
|
||||||
|
switch (elem.type) {
|
||||||
|
case LLAMA_GRETYPE_END:
|
||||||
|
case LLAMA_GRETYPE_ALT:
|
||||||
|
case LLAMA_GRETYPE_RULE_REF:
|
||||||
|
fprintf(file, "(%u) ", elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR:
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT:
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
|
fprintf(file, "(\"");
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
fprintf(file, "\") ");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fprintf(file, "\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print_rule(
|
||||||
|
FILE * file,
|
||||||
|
uint32_t rule_id,
|
||||||
|
const llama_grammar_rule & rule,
|
||||||
|
const std::map<uint32_t, std::string> & symbol_id_names) {
|
||||||
|
if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
|
||||||
|
}
|
||||||
|
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||||
|
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||||
|
llama_grammar_element elem = rule[i];
|
||||||
|
switch (elem.type) {
|
||||||
|
case LLAMA_GRETYPE_END:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
||||||
|
std::to_string(i));
|
||||||
|
case LLAMA_GRETYPE_ALT:
|
||||||
|
fprintf(file, "| ");
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_RULE_REF:
|
||||||
|
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR:
|
||||||
|
fprintf(file, "[");
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
|
fprintf(file, "[^");
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||||
|
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
||||||
|
std::to_string(rule_id) + "," + std::to_string(i));
|
||||||
|
}
|
||||||
|
fprintf(file, "-");
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT:
|
||||||
|
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
|
||||||
|
std::to_string(rule_id) + "," + std::to_string(i));
|
||||||
|
}
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
|
fprintf(file, ".");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (is_char_element(elem)) {
|
||||||
|
switch (rule[i + 1].type) {
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT:
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
fprintf(file, "] ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fprintf(file, "\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// implementation
|
||||||
|
//
|
||||||
|
|
||||||
|
uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
|
||||||
|
uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
|
||||||
|
auto result = symbol_ids.emplace(std::string(src, len), next_id);
|
||||||
|
return result.first->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
|
||||||
|
uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
|
||||||
|
symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
|
||||||
|
return next_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
|
||||||
|
if (rules.size() <= rule_id) {
|
||||||
|
rules.resize(rule_id + 1);
|
||||||
|
}
|
||||||
|
rules[rule_id] = rule;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * llama_grammar_parser::parse_alternates(
|
||||||
|
const char * src,
|
||||||
|
const std::string & rule_name,
|
||||||
|
uint32_t rule_id,
|
||||||
|
bool is_nested) {
|
||||||
|
llama_grammar_rule rule;
|
||||||
|
const char * pos = parse_sequence(src, rule_name, rule, is_nested);
|
||||||
|
while (*pos == '|') {
|
||||||
|
rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||||
|
pos = parse_space(pos + 1, true);
|
||||||
|
pos = parse_sequence(pos, rule_name, rule, is_nested);
|
||||||
|
}
|
||||||
|
rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
add_rule(rule_id, rule);
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * llama_grammar_parser::parse_sequence(
|
||||||
|
const char * src,
|
||||||
|
const std::string & rule_name,
|
||||||
|
llama_grammar_rule & rule,
|
||||||
|
bool is_nested) {
|
||||||
|
size_t last_sym_start = rule.size();
|
||||||
|
const char * pos = src;
|
||||||
|
|
||||||
|
auto handle_repetitions = [&](int min_times, int max_times) {
|
||||||
|
|
||||||
|
if (last_sym_start == rule.size()) {
|
||||||
|
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply transformation to previous symbol (last_sym_start to end) according to
|
||||||
|
// the following rewrite rules:
|
||||||
|
// S{m,n} --> S S S (m times) S'(n-m)
|
||||||
|
// S'(x) ::= S S'(x-1) |
|
||||||
|
// (... n-m definitions of these S' rules ...)
|
||||||
|
// S'(1) ::= S |
|
||||||
|
// S{m,} --> S S S (m times) S'
|
||||||
|
// S' ::= S S' |
|
||||||
|
// S* --> S{0,}
|
||||||
|
// --> S' ::= S S' |
|
||||||
|
// S+ --> S{1,}
|
||||||
|
// --> S S'
|
||||||
|
// S' ::= S S' |
|
||||||
|
// S? --> S{0,1}
|
||||||
|
// --> S'
|
||||||
|
// S' ::= S |
|
||||||
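// Concrete case (illustrative): for "x{2,4}" the two mandatory x's stay inline
// and the two optional repetitions become synthesized rules, roughly:
//   <rule> ::= ... x x x'2
//   x'2    ::= x x'1 |
//   x'1    ::= x |
// where x'1 and x'2 stand for the generated "<rule_name>_<id>" symbols.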
|
|
||||||
|
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
|
||||||
|
if (min_times == 0) {
|
||||||
|
rule.resize(last_sym_start);
|
||||||
|
} else {
|
||||||
|
// Repeat the previous elements (min_times - 1) times
|
||||||
|
for (int i = 1; i < min_times; i++) {
|
||||||
|
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t last_rec_rule_id = 0;
|
||||||
|
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
|
||||||
|
|
||||||
|
llama_grammar_rule rec_rule(prev_rule);
|
||||||
|
for (int i = 0; i < n_opt; i++) {
|
||||||
|
rec_rule.resize(prev_rule.size());
|
||||||
|
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
||||||
|
if (i > 0 || max_times < 0) {
|
||||||
|
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
|
||||||
|
}
|
||||||
|
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||||
|
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
add_rule( rec_rule_id, rec_rule);
|
||||||
|
last_rec_rule_id = rec_rule_id;
|
||||||
|
}
|
||||||
|
if (n_opt > 0) {
|
||||||
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
while (*pos) {
|
||||||
|
if (*pos == '"') { // literal string
|
||||||
|
pos++;
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
while (*pos != '"') {
|
||||||
|
if (!*pos) {
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
auto char_pair = parse_char(pos);
|
||||||
|
pos = char_pair.second;
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '[') { // char range(s)
|
||||||
|
pos++;
|
||||||
|
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
||||||
|
if (*pos == '^') {
|
||||||
|
pos++;
|
||||||
|
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||||
|
}
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
while (*pos != ']') {
|
||||||
|
if (!*pos) {
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
auto char_pair = parse_char(pos);
|
||||||
|
pos = char_pair.second;
|
||||||
|
enum llama_gretype type = last_sym_start < rule.size()
|
||||||
|
? LLAMA_GRETYPE_CHAR_ALT
|
||||||
|
: start_type;
|
||||||
|
|
||||||
|
rule.push_back({type, char_pair.first});
|
||||||
|
if (pos[0] == '-' && pos[1] != ']') {
|
||||||
|
if (!pos[1]) {
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
auto endchar_pair = parse_char(pos + 1);
|
||||||
|
pos = endchar_pair.second;
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (is_word_char(*pos)) { // rule reference
|
||||||
|
const char * name_end = parse_name(pos);
|
||||||
|
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
||||||
|
pos = parse_space(name_end, is_nested);
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
||||||
|
} else if (*pos == '(') { // grouping
|
||||||
|
// parse nested alternates into synthesized rule
|
||||||
|
pos = parse_space(pos + 1, true);
|
||||||
|
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
||||||
|
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
// output reference to synthesized rule
|
||||||
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
||||||
|
if (*pos != ')') {
|
||||||
|
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '.') { // any char
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '*') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
handle_repetitions(0, -1);
|
||||||
|
} else if (*pos == '+') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
handle_repetitions(1, -1);
|
||||||
|
} else if (*pos == '?') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
handle_repetitions(0, 1);
|
||||||
|
} else if (*pos == '{') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
|
||||||
|
if (!is_digit_char(*pos)) {
|
||||||
|
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
||||||
|
}
|
||||||
|
const char * int_end = parse_int(pos);
|
||||||
|
int min_times = std::stoul(std::string(pos, int_end - pos));
|
||||||
|
pos = parse_space(int_end, is_nested);
|
||||||
|
|
||||||
|
int max_times = -1;
|
||||||
|
|
||||||
|
if (*pos == '}') {
|
||||||
|
max_times = min_times;
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == ',') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
|
||||||
|
if (is_digit_char(*pos)) {
|
||||||
|
const char * int_end = parse_int(pos);
|
||||||
|
max_times = std::stoul(std::string(pos, int_end - pos));
|
||||||
|
pos = parse_space(int_end, is_nested);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (*pos != '}') {
|
||||||
|
throw std::runtime_error(std::string("expecting '}' at ") + pos);
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
||||||
|
}
|
||||||
|
handle_repetitions(min_times, max_times);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * llama_grammar_parser::parse_rule(const char * src) {
|
||||||
|
const char * name_end = parse_name(src);
|
||||||
|
const char * pos = parse_space(name_end, false);
|
||||||
|
size_t name_len = name_end - src;
|
||||||
|
uint32_t rule_id = get_symbol_id(src, name_len);
|
||||||
|
const std::string name(src, name_len);
|
||||||
|
|
||||||
|
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
||||||
|
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 3, true);
|
||||||
|
|
||||||
|
pos = parse_alternates(pos, name, rule_id, false);
|
||||||
|
|
||||||
|
if (*pos == '\r') {
|
||||||
|
pos += pos[1] == '\n' ? 2 : 1;
|
||||||
|
} else if (*pos == '\n') {
|
||||||
|
pos++;
|
||||||
|
} else if (*pos) {
|
||||||
|
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
||||||
|
}
|
||||||
|
return parse_space(pos, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_grammar_parser::parse(const char * src) {
|
||||||
|
try {
|
||||||
|
const char * pos = parse_space(src, true);
|
||||||
|
while (*pos) {
|
||||||
|
pos = parse_rule(pos);
|
||||||
|
}
|
||||||
|
// Validate the state to ensure that all rules are defined
|
||||||
|
for (const auto & rule : rules) {
|
||||||
|
if (rule.empty()) {
|
||||||
|
throw std::runtime_error("Undefined rule");
|
||||||
|
}
|
||||||
|
for (const auto & elem : rule) {
|
||||||
|
if (elem.type == LLAMA_GRETYPE_RULE_REF) {
|
||||||
|
// Ensure that the rule at that location exists
|
||||||
|
if (elem.value >= rules.size() || rules[elem.value].empty()) {
|
||||||
|
// Get the name of the rule that is missing
|
||||||
|
for (const auto & kv : symbol_ids) {
|
||||||
|
if (kv.second == elem.value) {
|
||||||
|
throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception & err) {
|
||||||
|
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||||
|
rules.clear();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
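A hedged sketch of driving the parser interface directly (the grammar text is made up for illustration; real callers go through the string-based llama_grammar_init_impl() below):

    llama_grammar_parser parser;
    if (parser.parse("root ::= word (\" \" word)*\nword ::= [a-z]+\n")) {
        parser.print(stderr);  // dump the parsed rules in a GBNF-like form
        std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
        // grammar_rules feeds the array-based llama_grammar_init_impl() overload
    }
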
|
void llama_grammar_parser::print(FILE * file) {
|
||||||
|
try {
|
||||||
|
std::map<uint32_t, std::string> symbol_id_names;
|
||||||
|
for (const auto & kv : symbol_ids) {
|
||||||
|
symbol_id_names[kv.second] = kv.first;
|
||||||
|
}
|
||||||
|
for (size_t i = 0, end = rules.size(); i < end; i++) {
|
||||||
|
// fprintf(file, "%zu: ", i);
|
||||||
|
// print_rule_binary(file, rules[i]);
|
||||||
|
print_rule(file, uint32_t(i), rules[i], symbol_id_names);
|
||||||
|
// fprintf(file, "\n");
|
||||||
|
}
|
||||||
|
} catch (const std::exception & err) {
|
||||||
|
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_grammar_stack llama_grammar_parser::c_rules() const {
|
||||||
|
llama_grammar_stack ret;
|
||||||
|
ret.reserve(rules.size());
|
||||||
|
for (const auto & rule : rules) {
|
||||||
|
ret.push_back(rule.data());
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns true iff pos points to the end of one of the definitions of a rule
|
// returns true iff pos points to the end of one of the definitions of a rule
|
||||||
@ -89,7 +606,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
|
|||||||
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
|
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
|
||||||
const llama_grammar_element * pos,
|
const llama_grammar_element * pos,
|
||||||
const uint32_t chr) {
|
const uint32_t chr) {
|
||||||
|
|
||||||
bool found = false;
|
bool found = false;
|
||||||
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
|
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
|
||||||
|
|
||||||
@ -225,36 +741,6 @@ static void llama_grammar_advance_stack(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// takes a set of possible pushdown stacks on a grammar, which are required to
|
|
||||||
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
|
||||||
// produces the N possible stacks if the given char is accepted at those
|
|
||||||
// positions
|
|
||||||
void llama_grammar_accept(
|
|
||||||
const llama_grammar_rules & rules,
|
|
||||||
const llama_grammar_stacks & stacks,
|
|
||||||
const uint32_t chr,
|
|
||||||
llama_grammar_stacks & new_stacks) {
|
|
||||||
new_stacks.clear();
|
|
||||||
|
|
||||||
for (const auto & stack : stacks) {
|
|
||||||
if (stack.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto match = llama_grammar_match_char(stack.back(), chr);
|
|
||||||
if (match.first) {
|
|
||||||
const llama_grammar_element * pos = match.second;
|
|
||||||
|
|
||||||
// update top of stack to next element, if any
|
|
||||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
|
||||||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
|
||||||
new_stack.push_back(pos);
|
|
||||||
}
|
|
||||||
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static llama_grammar_candidates llama_grammar_reject_candidates(
|
static llama_grammar_candidates llama_grammar_reject_candidates(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stacks & stacks,
|
const llama_grammar_stacks & stacks,
|
||||||
@ -270,9 +756,99 @@ static llama_grammar_candidates llama_grammar_reject_candidates(
|
|||||||
for (size_t i = 1, size = stacks.size(); i < size; ++i) {
|
for (size_t i = 1, size = stacks.size(); i < size; ++i) {
|
||||||
rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
|
rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
|
||||||
}
|
}
|
||||||
|
|
||||||
return rejects;
|
return rejects;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool llama_grammar_detect_left_recursion(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
size_t rule_index,
|
||||||
|
std::vector<bool> * rules_visited,
|
||||||
|
std::vector<bool> * rules_in_progress,
|
||||||
|
std::vector<bool> * rules_may_be_empty) {
|
||||||
|
if ((*rules_in_progress)[rule_index]) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
(*rules_in_progress)[rule_index] = true;
|
||||||
|
|
||||||
|
const llama_grammar_rule & rule = rules[rule_index];
|
||||||
|
|
||||||
|
// First check if the rule might produce the empty string. This could be done combined with the second
|
||||||
|
// step but it's more readable as two steps.
|
||||||
|
bool at_rule_start = true;
|
||||||
|
for (size_t i = 0; i < rule.size(); i++) {
|
||||||
|
if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
||||||
|
if (at_rule_start) {
|
||||||
|
(*rules_may_be_empty)[rule_index] = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
at_rule_start = true;
|
||||||
|
} else {
|
||||||
|
at_rule_start = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
|
||||||
|
// be empty)
|
||||||
|
bool recurse_into_nonterminal = true;
|
||||||
|
for (size_t i = 0; i < rule.size(); i++) {
|
||||||
|
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
|
||||||
|
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
|
||||||
|
recurse_into_nonterminal = false;
|
||||||
|
}
|
||||||
|
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
||||||
|
recurse_into_nonterminal = true;
|
||||||
|
} else {
|
||||||
|
recurse_into_nonterminal = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(*rules_in_progress)[rule_index] = false;
|
||||||
|
(*rules_visited)[rule_index] = true;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
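// Example of what this check rejects: a rule such as
//     expr ::= expr "+" term | term
// references itself before consuming any character (direct left recursion), which
// would let the pushdown stacks grow without bound; the string-based initializer
// below logs "left recursion detected" and returns nullptr for such grammars.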
|
const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
|
||||||
|
return grammar->rules;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
|
||||||
|
return grammar->stacks;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_grammar_stacks llama_grammar_accept(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stacks & stacks,
|
||||||
|
const uint32_t chr) {
|
||||||
|
llama_grammar_stacks result;
|
||||||
|
result.reserve(stacks.size());
|
||||||
|
|
||||||
|
for (const auto & stack : stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto match = llama_grammar_match_char(stack.back(), chr);
|
||||||
|
if (match.first) {
|
||||||
|
const llama_grammar_element * pos = match.second;
|
||||||
|
|
||||||
|
// update top of stack to next element, if any
|
||||||
|
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||||
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
|
new_stack.push_back(pos);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(rules, new_stack, result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
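The call-site change implied by the new return-by-value form, as a brief sketch with `rules`, `stacks` and `chr` assumed in scope (the old out-parameter form is kept as a comment for contrast):

    // before: llama_grammar_accept(rules, stacks, chr, new_stacks);
    llama_grammar_stacks new_stacks = llama_grammar_accept(rules, stacks, chr);
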
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stack & stack,
|
const llama_grammar_stack & stack,
|
||||||
@ -328,63 +904,10 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|||||||
return rejects;
|
return rejects;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_grammar_detect_left_recursion(
|
////////////////////
|
||||||
const llama_grammar_rules & rules,
|
|
||||||
size_t rule_index,
|
|
||||||
std::vector<bool> * rules_visited,
|
|
||||||
std::vector<bool> * rules_in_progress,
|
|
||||||
std::vector<bool> * rules_may_be_empty) {
|
|
||||||
if ((*rules_in_progress)[rule_index]) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
(*rules_in_progress)[rule_index] = true;
|
|
||||||
|
|
||||||
const llama_grammar_rule & rule = rules[rule_index];
|
|
||||||
|
|
||||||
// First check if the rule might produce the empty string. This could be done combined with the second
|
|
||||||
// step but it's more readable as two steps.
|
|
||||||
bool at_rule_start = true;
|
|
||||||
for (size_t i = 0; i < rule.size(); i++) {
|
|
||||||
if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
|
||||||
if (at_rule_start) {
|
|
||||||
(*rules_may_be_empty)[rule_index] = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
at_rule_start = true;
|
|
||||||
} else {
|
|
||||||
at_rule_start = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
|
|
||||||
// be empty)
|
|
||||||
bool recurse_into_nonterminal = true;
|
|
||||||
for (size_t i = 0; i < rule.size(); i++) {
|
|
||||||
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
|
|
||||||
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
|
|
||||||
recurse_into_nonterminal = false;
|
|
||||||
}
|
|
||||||
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
|
||||||
recurse_into_nonterminal = true;
|
|
||||||
} else {
|
|
||||||
recurse_into_nonterminal = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
(*rules_in_progress)[rule_index] = false;
|
|
||||||
(*rules_visited)[rule_index] = true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// grammar - external
|
|
||||||
//
|
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_init_impl(
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
const llama_grammar_element ** rules,
|
const llama_grammar_element ** rules,
|
||||||
size_t n_rules,
|
size_t n_rules,
|
||||||
size_t start_rule_index) {
|
size_t start_rule_index) {
|
||||||
@ -438,22 +961,100 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
|
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
|
||||||
|
llama_grammar_parser parser;
|
||||||
|
|
||||||
|
// if there is a grammar, parse it
|
||||||
|
if (!parser.parse(grammar_str)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// will be empty (default) if there are parse errors
|
||||||
|
if (parser.rules.empty()) {
|
||||||
|
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that there is a "root" node.
|
||||||
|
if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
|
||||||
|
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
|
||||||
|
|
||||||
|
const size_t n_rules = grammar_rules.size();
|
||||||
|
const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
|
||||||
|
|
||||||
|
const llama_grammar_element * pos;
|
||||||
|
|
||||||
|
// copy rule definitions into vectors
|
||||||
|
llama_grammar_rules vec_rules(n_rules);
|
||||||
|
for (size_t i = 0; i < n_rules; i++) {
|
||||||
|
for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
|
||||||
|
vec_rules[i].push_back(*pos);
|
||||||
|
}
|
||||||
|
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for left recursion
|
||||||
|
std::vector<bool> rules_visited(n_rules);
|
||||||
|
std::vector<bool> rules_in_progress(n_rules);
|
||||||
|
std::vector<bool> rules_may_be_empty(n_rules);
|
||||||
|
for (size_t i = 0; i < n_rules; i++) {
|
||||||
|
if (rules_visited[i]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
|
||||||
|
LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop over alternates of start rule to build initial stacks
|
||||||
|
llama_grammar_stacks stacks;
|
||||||
|
pos = vec_rules[start_rule_index].data();
|
||||||
|
do {
|
||||||
|
llama_grammar_stack stack;
|
||||||
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
|
// if alternate is nonempty, add to stack
|
||||||
|
stack.push_back(pos);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(vec_rules, stack, stacks);
|
||||||
|
while (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
|
// scan to end of alternate def
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
if (pos->type == LLAMA_GRETYPE_ALT) {
|
||||||
|
// there's another alternate def of this rule to process
|
||||||
|
pos++;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (true);
|
||||||
|
|
||||||
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
|
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||||
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
|
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
||||||
}
|
}
|
||||||
|
|
||||||
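A hedged sketch of the new string-based constructor path; the `vocab` pointer and the grammar text are assumptions for illustration, and the nullptr check mirrors the failure paths above (parse error, missing "root" symbol, or left recursion):

    const char * gbnf =
        "root ::= item+\n"
        "item ::= \"- \" [a-zA-Z ]+ \"\\n\"\n";

    struct llama_grammar * grammar = llama_grammar_init_impl(vocab, gbnf, "root");
    if (grammar != nullptr) {
        // ... drive sampling with llama_grammar_apply_impl / llama_grammar_accept_impl ...
        llama_grammar_free_impl(grammar);
    }
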
void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
||||||
delete grammar;
|
delete grammar;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
|
struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar) {
|
||||||
llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
|
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
|
||||||
|
|
||||||
// redirect elements in stacks to point to new rules
|
// redirect elements in stacks to point to new rules
|
||||||
for (size_t is = 0; is < result->stacks.size(); is++) {
|
for (size_t is = 0; is < result->stacks.size(); is++) {
|
||||||
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
|
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
|
||||||
for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
|
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
|
||||||
for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
|
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
|
||||||
if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
|
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
|
||||||
result->stacks[is][ie] = &result->rules[ir0][ir1];
|
result->stacks[is][ie] = &result->rules[ir0][ir1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -464,14 +1065,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * candidates) {
|
||||||
GGML_ASSERT(grammar);
|
GGML_ASSERT(grammar.vocab != nullptr);
|
||||||
GGML_ASSERT(vocab);
|
|
||||||
|
|
||||||
int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
bool allow_eog = false;
|
bool allow_eog = false;
|
||||||
for (const auto & stack : grammar->stacks) {
|
for (const auto & stack : grammar.stacks) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
allow_eog = true;
|
allow_eog = true;
|
||||||
break;
|
break;
|
||||||
@ -486,33 +1084,31 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
|
|||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
const llama_token id = candidates->data[i].id;
|
const llama_token id = candidates->data[i].id;
|
||||||
const std::string & piece = vocab->cache_token_to_piece.at(id);
|
const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
|
||||||
|
|
||||||
if (llama_token_is_eog_impl(*vocab, id)) {
|
if (llama_token_is_eog_impl(*grammar.vocab, id)) {
|
||||||
if (!allow_eog) {
|
if (!allow_eog) {
|
||||||
candidates->data[i].logit = -INFINITY;
|
candidates->data[i].logit = -INFINITY;
|
||||||
}
|
}
|
||||||
} else if (piece.empty() || piece[0] == 0) {
|
} else if (piece.empty() || piece[0] == 0) {
|
||||||
candidates->data[i].logit = -INFINITY;
|
candidates->data[i].logit = -INFINITY;
|
||||||
} else {
|
} else {
|
||||||
candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
|
candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
|
||||||
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
|
const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
|
||||||
for (const auto & reject : rejects) {
|
for (const auto & reject : rejects) {
|
||||||
candidates->data[reject.index].logit = -INFINITY;
|
candidates->data[reject.index].logit = -INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
|
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
GGML_ASSERT(grammar.vocab != nullptr);
|
||||||
|
|
||||||
if (llama_token_is_eog_impl(*vocab, token)) {
|
if (llama_token_is_eog_impl(*grammar.vocab, token)) {
|
||||||
for (const auto & stack : grammar->stacks) {
|
for (const auto & stack : grammar.stacks) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -520,20 +1116,17 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
|
|||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string & piece = vocab->cache_token_to_piece.at(token);
|
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
|
||||||
|
|
||||||
// Note terminating 0 in decoded string
|
// Note terminating 0 in decoded string
|
||||||
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
|
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
|
||||||
llama_grammar_stacks tmp_new_stacks;
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
|
llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it);
|
||||||
grammar->stacks = tmp_new_stacks;
|
grammar.stacks = std::move(new_stacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
grammar->partial_utf8 = decoded.second;
|
grammar.partial_utf8 = decoded.second;
|
||||||
GGML_ASSERT(!grammar->stacks.empty());
|
GGML_ASSERT(!grammar.stacks.empty());
|
||||||
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
}
|
||||||
|
@ -2,11 +2,114 @@
|
|||||||
|
|
||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
struct llama_vocab;
|
struct llama_vocab;
|
||||||
struct llama_sampling;
|
|
||||||
|
// grammar element type
|
||||||
|
enum llama_gretype {
|
||||||
|
// end of rule definition
|
||||||
|
LLAMA_GRETYPE_END = 0,
|
||||||
|
|
||||||
|
// start of alternate definition for rule
|
||||||
|
LLAMA_GRETYPE_ALT = 1,
|
||||||
|
|
||||||
|
// non-terminal element: reference to rule
|
||||||
|
LLAMA_GRETYPE_RULE_REF = 2,
|
||||||
|
|
||||||
|
// terminal element: character (code point)
|
||||||
|
LLAMA_GRETYPE_CHAR = 3,
|
||||||
|
|
||||||
|
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||||
|
LLAMA_GRETYPE_CHAR_NOT = 4,
|
||||||
|
|
||||||
|
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||||
|
// be an inclusive range ([a-z])
|
||||||
|
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||||
|
|
||||||
|
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
||||||
|
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||||
|
LLAMA_GRETYPE_CHAR_ALT = 6,
|
||||||
|
|
||||||
|
// any character (.)
|
||||||
|
LLAMA_GRETYPE_CHAR_ANY = 7,
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct llama_grammar_element {
|
||||||
|
enum llama_gretype type;
|
||||||
|
uint32_t value; // Unicode code point or rule ID
|
||||||
|
} llama_grammar_element;
|
||||||
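As an illustration of this encoding (the rule is hypothetical, not part of the diff), a GBNF rule such as root ::= "a" [b-z] would be stored as a single llama_grammar_rule:

const llama_grammar_rule root_rule = {
    { LLAMA_GRETYPE_CHAR,           'a' }, // literal 'a'
    { LLAMA_GRETYPE_CHAR,           'b' }, // lower bound of [b-z] ...
    { LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z' }, // ... turned into an inclusive range by this element
    { LLAMA_GRETYPE_END,            0   }, // end of rule definition
};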
|
|
||||||
|
struct llama_partial_utf8 {
|
||||||
|
uint32_t value; // bit value so far (unshifted)
|
||||||
|
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_grammar_candidate {
|
||||||
|
size_t index;
|
||||||
|
const uint32_t * code_points;
|
||||||
|
llama_partial_utf8 partial_utf8;
|
||||||
|
};
|
||||||
|
|
||||||
|
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
||||||
|
using llama_grammar_stack = std::vector<const llama_grammar_element *>;
|
||||||
|
|
||||||
|
using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
||||||
|
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
||||||
|
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
||||||
|
|
||||||
|
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
||||||
|
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
||||||
|
|
||||||
|
// takes a set of possible pushdown stacks on a grammar, which are required to
|
||||||
|
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||||
|
// produces the N possible stacks if the given char is accepted at those
|
||||||
|
// positions
|
||||||
|
llama_grammar_stacks llama_grammar_accept(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stacks & stacks,
|
||||||
|
uint32_t chr);
|
||||||
|
|
||||||
|
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stack & stack,
|
||||||
|
const llama_grammar_candidates & candidates);
|
||||||
|
|
||||||
|
struct llama_grammar_parser {
|
||||||
|
std::map<std::string, uint32_t> symbol_ids;
|
||||||
|
|
||||||
|
llama_grammar_rules rules;
|
||||||
|
|
||||||
|
llama_grammar_stack c_rules() const;
|
||||||
|
|
||||||
|
uint32_t get_symbol_id(const char * src, size_t len);
|
||||||
|
uint32_t generate_symbol_id(const std::string & base_name);
|
||||||
|
|
||||||
|
void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
|
||||||
|
|
||||||
|
const char * parse_alternates(
|
||||||
|
const char * src,
|
||||||
|
const std::string & rule_name,
|
||||||
|
uint32_t rule_id,
|
||||||
|
bool is_nested);
|
||||||
|
|
||||||
|
const char * parse_sequence(
|
||||||
|
const char * src,
|
||||||
|
const std::string & rule_name,
|
||||||
|
llama_grammar_rule & rule,
|
||||||
|
bool is_nested);
|
||||||
|
|
||||||
|
const char * parse_rule(const char * src);
|
||||||
|
|
||||||
|
bool parse(const char * src);
|
||||||
|
void print(FILE * file);
|
||||||
|
};
|
||||||
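Typical use of the parser type above, sketched with a hypothetical grammar string; parse() returns false on a malformed grammar and print() dumps the parsed rules:

llama_grammar_parser parser;
if (parser.parse("root ::= \"yes\" | \"no\"")) {
    parser.print(stderr);   // inspect parser.rules / parser.symbol_ids
} else {
    // the source text could not be parsed
}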
|
|
||||||
struct llama_grammar {
|
struct llama_grammar {
|
||||||
const llama_grammar_rules rules;
|
// note: allow null vocab for testing (not great)
|
||||||
|
const llama_vocab * vocab;
|
||||||
|
|
||||||
|
const llama_grammar_rules rules; // TODO: shared ptr
|
||||||
llama_grammar_stacks stacks;
|
llama_grammar_stacks stacks;
|
||||||
|
|
||||||
// buffer for partially generated UTF-8 sequence from accepted tokens
|
// buffer for partially generated UTF-8 sequence from accepted tokens
|
||||||
@ -17,23 +120,24 @@ struct llama_grammar {
|
|||||||
// internal API
|
// internal API
|
||||||
//
|
//
|
||||||
|
|
||||||
|
// note: needed for tests (not great)
|
||||||
struct llama_grammar * llama_grammar_init_impl(
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
const llama_grammar_element ** rules,
|
const llama_grammar_element ** rules,
|
||||||
size_t n_rules,
|
size_t n_rules,
|
||||||
size_t start_rule_index);
|
size_t start_rule_index);
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
|
||||||
|
|
||||||
void llama_grammar_free_impl(struct llama_grammar * grammar);
|
void llama_grammar_free_impl(struct llama_grammar * grammar);
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
|
struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar);
|
||||||
|
|
||||||
void llama_grammar_sample_impl(
|
// TODO: move the API below as member functions of llama_grammar
|
||||||
const struct llama_grammar * grammar,
|
void llama_grammar_apply_impl(
|
||||||
const struct llama_vocab * vocab,
|
const struct llama_grammar & grammar,
|
||||||
const struct llama_sampling * smpl,
|
|
||||||
llama_token_data_array * candidates);
|
llama_token_data_array * candidates);
|
||||||
|
|
||||||
void llama_grammar_accept_token_impl(
|
void llama_grammar_accept_impl(
|
||||||
struct llama_grammar * grammar,
|
struct llama_grammar & grammar,
|
||||||
const struct llama_vocab * vocab,
|
|
||||||
const struct llama_sampling * smpl,
|
|
||||||
llama_token token);
|
llama_token token);
|
||||||
|
116 src/llama-impl.h
@ -1,8 +1,11 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#define LLAMA_API_INTERNAL
|
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
#ifdef __GNUC__
|
#ifdef __GNUC__
|
||||||
#ifdef __MINGW32__
|
#ifdef __MINGW32__
|
||||||
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||||
@ -45,3 +48,114 @@ static void replace_all(std::string & s, const std::string & search, const std::
|
|||||||
builder.append(s, last_pos, std::string::npos);
|
builder.append(s, last_pos, std::string::npos);
|
||||||
s = std::move(builder);
|
s = std::move(builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
||||||
|
struct llama_context * ctx
|
||||||
|
);
|
||||||
|
|
||||||
|
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||||
|
template<typename T>
|
||||||
|
struct ring_buffer {
|
||||||
|
ring_buffer() {}
|
||||||
|
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
||||||
|
|
||||||
|
T & front() {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
return data[first];
|
||||||
|
}
|
||||||
|
|
||||||
|
const T & front() const {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
return data[first];
|
||||||
|
}
|
||||||
|
|
||||||
|
T & back() {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
return data[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
const T & back() const {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
return data[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
void push_back(const T & value) {
|
||||||
|
if (sz == capacity) {
|
||||||
|
// advance the start when buffer is full
|
||||||
|
first = (first + 1) % capacity;
|
||||||
|
} else {
|
||||||
|
sz++;
|
||||||
|
}
|
||||||
|
data[pos] = value;
|
||||||
|
pos = (pos + 1) % capacity;
|
||||||
|
}
|
||||||
|
|
||||||
|
T pop_front() {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
T value = data[first];
|
||||||
|
first = (first + 1) % capacity;
|
||||||
|
sz--;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
//T & operator[](size_t i) {
|
||||||
|
// if (i >= sz) {
|
||||||
|
// throw std::runtime_error("ring buffer: index out of bounds");
|
||||||
|
// }
|
||||||
|
// return data[(first + i) % capacity];
|
||||||
|
//}
|
||||||
|
|
||||||
|
//const T & at(size_t i) const {
|
||||||
|
// if (i >= sz) {
|
||||||
|
// throw std::runtime_error("ring buffer: index out of bounds");
|
||||||
|
// }
|
||||||
|
// return data[(first + i) % capacity];
|
||||||
|
//}
|
||||||
|
|
||||||
|
const T & rat(size_t i) const {
|
||||||
|
if (i >= sz) {
|
||||||
|
throw std::runtime_error("ring buffer: index out of bounds");
|
||||||
|
}
|
||||||
|
return data[(first + sz - i - 1) % capacity];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<T> to_vector() const {
|
||||||
|
std::vector<T> result;
|
||||||
|
result.reserve(sz);
|
||||||
|
for (size_t i = 0; i < sz; i++) {
|
||||||
|
result.push_back(data[(first + i) % capacity]);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
// here only reset the status of the buffer
|
||||||
|
sz = 0;
|
||||||
|
first = 0;
|
||||||
|
pos = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool empty() const {
|
||||||
|
return sz == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return sz;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t capacity = 0;
|
||||||
|
size_t sz = 0;
|
||||||
|
size_t first = 0;
|
||||||
|
size_t pos = 0;
|
||||||
|
std::vector<T> data;
|
||||||
|
};
|
||||||
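A small usage sketch for the ring buffer (values are illustrative): once the capacity is reached, push_back() overwrites the oldest element, and rat(i) indexes backwards from the most recently pushed element.

ring_buffer<int> rb(3);   // capacity 3
rb.push_back(1);
rb.push_back(2);
rb.push_back(3);
rb.push_back(4);          // full: drops the oldest element (1)
// rb.to_vector() == {2, 3, 4}
// rb.rat(0) == 4 (most recent), rb.rat(2) == 2 (oldest retained)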
|
@ -1,5 +1,8 @@
|
|||||||
#include "llama-sampling.h"
|
#include "llama-sampling.h"
|
||||||
|
|
||||||
|
#include "llama-vocab.h"
|
||||||
|
#include "llama-grammar.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
@ -21,18 +24,104 @@ static void llama_log_softmax(float * array, size_t size) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
|
llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) {
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampling::~llama_sampling() {
|
||||||
|
if (grammar) {
|
||||||
|
llama_grammar_free_impl(grammar);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) {
|
||||||
|
auto * result = new llama_sampling(vocab);
|
||||||
|
|
||||||
|
result->params = params;
|
||||||
|
|
||||||
|
result->prev = ring_buffer<llama_token>(params.n_prev);
|
||||||
|
|
||||||
|
for (int i = 0; i < params.n_samplers; ++i) {
|
||||||
|
result->samplers.push_back(params.samplers[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampling_set_rng_seed_impl(*result, params.seed);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sampling_free_impl(struct llama_sampling * sampling) {
|
||||||
|
delete sampling;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
|
||||||
|
auto * result = new llama_sampling(smpl.vocab);
|
||||||
|
|
||||||
|
result->params = smpl.params;
|
||||||
|
|
||||||
|
result->grammar_str = smpl.grammar_str;
|
||||||
|
result->grammar_root = smpl.grammar_root;
|
||||||
|
|
||||||
|
result->logit_bias = smpl.logit_bias;
|
||||||
|
|
||||||
|
if (smpl.grammar) {
|
||||||
|
result->grammar = llama_grammar_cp_impl(*smpl.grammar);
|
||||||
|
}
|
||||||
|
|
||||||
|
result->rng = smpl.rng;
|
||||||
|
result->prev = smpl.prev;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sampling_reset_impl(struct llama_sampling & smpl) {
|
||||||
|
if (smpl.grammar) {
|
||||||
|
llama_grammar_free_impl(smpl.grammar);
|
||||||
|
smpl.grammar = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!smpl.grammar_str.empty()) {
|
||||||
|
smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
smpl.prev.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) {
|
||||||
if (seed == LLAMA_DEFAULT_SEED) {
|
if (seed == LLAMA_DEFAULT_SEED) {
|
||||||
seed = time(NULL);
|
seed = time(NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
smpl->rng.seed(seed);
|
smpl.rng.seed(seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
|
||||||
GGML_ASSERT(candidates->size > 0);
|
if (smpl.grammar) {
|
||||||
|
llama_grammar_free_impl(smpl.grammar);
|
||||||
|
smpl.grammar = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||||
|
smpl.grammar_str = grammar_str;
|
||||||
|
smpl.grammar_root = grammar_root;
|
||||||
|
|
||||||
|
smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root);
|
||||||
|
} else {
|
||||||
|
smpl.grammar_str.clear();
|
||||||
|
smpl.grammar_root.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
|
||||||
|
smpl.logit_bias.clear();
|
||||||
|
smpl.logit_bias.reserve(n_logit_bias);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < n_logit_bias; ++i) {
|
||||||
|
smpl.logit_bias.push_back(logit_bias[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
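Putting the pieces above together, the internal lifecycle of the new sampling state looks roughly like this (a sketch: `vocab` stands for the vocab of a loaded model, and the grammar string is hypothetical):

llama_sampling_params sparams = llama_sampling_default_params();

struct llama_sampling * smpl = llama_sampling_init_impl(vocab, sparams);

llama_sampling_set_rng_seed_impl(*smpl, 42);
llama_sampling_set_grammar_impl (*smpl, "root ::= [0-9]+", "root");   // optional constraint

// ... per token: apply the samplers, pick a token, then
// llama_sampling_accept_impl(*smpl, token, /*apply_grammar =*/ true);

llama_sampling_free_impl(smpl);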
|
void llama_sampling_softmax_impl(llama_token_data_array * candidates) {
|
||||||
|
GGML_ASSERT(candidates->size > 0);
|
||||||
|
|
||||||
// Sort the logits in descending order
|
// Sort the logits in descending order
|
||||||
if (!candidates->sorted) {
|
if (!candidates->sorted) {
|
||||||
@ -44,28 +133,24 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||||||
|
|
||||||
float max_l = candidates->data[0].logit;
|
float max_l = candidates->data[0].logit;
|
||||||
float cum_sum = 0.0f;
|
float cum_sum = 0.0f;
|
||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
float p = expf(candidates->data[i].logit - max_l);
|
float p = expf(candidates->data[i].logit - max_l);
|
||||||
candidates->data[i].p = p;
|
candidates->data[i].p = p;
|
||||||
cum_sum += p;
|
cum_sum += p;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
candidates->data[i].p /= cum_sum;
|
candidates->data[i].p /= cum_sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
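For reference, the softmax above is the standard max-subtracted form, p_i = exp(l_i - max_l) / sum_j exp(l_j - max_l); subtracting the maximum logit keeps expf() from overflowing. A standalone sketch (not the library code; assumes a non-empty input, as the assert above enforces):

#include <algorithm>
#include <cmath>
#include <vector>

static void softmax_sketch(std::vector<float> & v) {
    const float max_l = *std::max_element(v.begin(), v.end());
    float cum_sum = 0.0f;
    for (float & x : v) { x = expf(x - max_l); cum_sum += x; }   // exponentiate relative to the max
    for (float & x : v) { x /= cum_sum; }                        // normalize to probabilities
}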
void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||||
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
||||||
// if (k >= (int32_t)candidates->size) {
|
// if (k >= (int32_t)candidates->size) {
|
||||||
// return;
|
// return;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
if (k <= 0) {
|
if (k <= 0) {
|
||||||
k = candidates->size;
|
k = candidates->size;
|
||||||
}
|
}
|
||||||
@ -101,10 +186,12 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|||||||
int ib = nbuckets - 1;
|
int ib = nbuckets - 1;
|
||||||
for ( ; ib >= 0; --ib) {
|
for ( ; ib >= 0; --ib) {
|
||||||
nhave += histo[ib];
|
nhave += histo[ib];
|
||||||
if (nhave >= k) break;
|
if (nhave >= k) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
std::vector<llama_token_data> tmp_tokens(nhave);
|
std::vector<llama_token_data> tmp_tokens(nhave);
|
||||||
auto ptr = tmp_tokens.data();
|
auto * ptr = tmp_tokens.data();
|
||||||
std::vector<llama_token_data*> bucket_ptrs;
|
std::vector<llama_token_data*> bucket_ptrs;
|
||||||
bucket_ptrs.reserve(nbuckets - ib);
|
bucket_ptrs.reserve(nbuckets - ib);
|
||||||
for (int j = nbuckets - 1; j >= ib; --j) {
|
for (int j = nbuckets - 1; j >= ib; --j) {
|
||||||
@ -133,20 +220,14 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|||||||
candidates->sorted = true;
|
candidates->sorted = true;
|
||||||
}
|
}
|
||||||
candidates->size = k;
|
candidates->size = k;
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
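Conceptually, top-k just keeps the k highest-logit candidates; the bucket-sort path above is a faster equivalent of the partial sort below (standalone sketch, not the library code):

#include <algorithm>
#include <vector>

struct tok { int id; float logit; };

static void top_k_sketch(std::vector<tok> & cand, int k) {
    if (k <= 0) { k = (int) cand.size(); }    // k <= 0 means "keep everything", as above
    k = std::min(k, (int) cand.size());
    std::partial_sort(cand.begin(), cand.begin() + k, cand.end(),
                      [](const tok & a, const tok & b) { return a.logit > b.logit; });
    cand.resize(k);
}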
void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||||
if (p >= 1.0f) {
|
if (p >= 1.0f) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sample_softmax_impl(smpl, candidates);
|
llama_sampling_softmax_impl(candidates);
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
// Compute the cumulative probabilities
|
// Compute the cumulative probabilities
|
||||||
float cum_sum = 0.0f;
|
float cum_sum = 0.0f;
|
||||||
@ -165,19 +246,13 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|||||||
|
|
||||||
// Resize the output vector to keep only the top-p tokens
|
// Resize the output vector to keep only the top-p tokens
|
||||||
candidates->size = last_idx;
|
candidates->size = last_idx;
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
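Top-p (nucleus) truncation keeps the smallest prefix of the probability-sorted candidates whose cumulative mass reaches p, but never fewer than min_keep. A standalone sketch over probabilities already sorted in descending order:

#include <vector>

static size_t top_p_keep_sketch(const std::vector<float> & probs_desc, float p, size_t min_keep) {
    float  cum_sum  = 0.0f;
    size_t last_idx = probs_desc.size();
    for (size_t i = 0; i < probs_desc.size(); ++i) {
        cum_sum += probs_desc[i];
        if (cum_sum >= p && i + 1 >= min_keep) {
            last_idx = i + 1;   // number of candidates to keep
            break;
        }
    }
    return last_idx;
}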
void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||||
if (p <= 0.0f || !candidates->size) {
|
if (p <= 0.0f || !candidates->size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
bool min_p_applied = false;
|
bool min_p_applied = false;
|
||||||
|
|
||||||
// if the candidates aren't sorted, try the unsorted implementation first
|
// if the candidates aren't sorted, try the unsorted implementation first
|
||||||
@ -226,19 +301,14 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|||||||
// Resize the output vector to keep only the matching tokens
|
// Resize the output vector to keep only the matching tokens
|
||||||
candidates->size = i;
|
candidates->size = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
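Min-p keeps only candidates whose probability is at least p times the top probability, again with a min_keep floor. Conceptually, over probabilities sorted in descending order (the library also has an unsorted fast path working on logits):

#include <vector>

static void min_p_sketch(std::vector<float> & probs_desc, float p, size_t min_keep) {
    if (probs_desc.empty() || p <= 0.0f) { return; }
    const float threshold = p * probs_desc[0];   // probs_desc[0] is the largest probability
    size_t keep = probs_desc.size();
    for (size_t i = 0; i < probs_desc.size(); ++i) {
        if (probs_desc[i] < threshold && i >= min_keep) { keep = i; break; }
    }
    probs_desc.resize(keep);
}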
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
|
void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||||
if (z >= 1.0f || candidates->size <= 2) {
|
if (z >= 1.0f || candidates->size <= 2) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
llama_sampling_softmax_impl(candidates);
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
// Compute the first and second derivatives
|
// Compute the first and second derivatives
|
||||||
std::vector<float> first_derivatives(candidates->size - 1);
|
std::vector<float> first_derivatives(candidates->size - 1);
|
||||||
@ -285,13 +355,9 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
|
|||||||
|
|
||||||
// Resize the output vector to keep only the tokens above the tail location
|
// Resize the output vector to keep only the tokens above the tail location
|
||||||
candidates->size = last_idx;
|
candidates->size = last_idx;
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||||
// Reference implementation:
|
// Reference implementation:
|
||||||
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
||||||
if (p >= 1.0f) {
|
if (p >= 1.0f) {
|
||||||
@ -299,9 +365,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compute the softmax of logits and calculate entropy
|
// Compute the softmax of logits and calculate entropy
|
||||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
llama_sampling_softmax_impl(candidates);
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
float entropy = 0.0f;
|
float entropy = 0.0f;
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
@ -349,15 +413,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||||||
std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
|
std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
|
||||||
candidates->size = new_candidates.size();
|
candidates->size = new_candidates.size();
|
||||||
candidates->sorted = false;
|
candidates->sorted = false;
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
|
void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
// no need to do anything if there is only one (or zero) candidates
|
// no need to do anything if there is only one (or zero) candidates
|
||||||
if(candidates->size <= 1) {
|
if(candidates->size <= 1) {
|
||||||
return;
|
return;
|
||||||
@ -366,7 +424,7 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||||||
// Calculate maximum possible entropy
|
// Calculate maximum possible entropy
|
||||||
float max_entropy = -logf(1.0f / candidates->size);
|
float max_entropy = -logf(1.0f / candidates->size);
|
||||||
|
|
||||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
llama_sampling_softmax_impl(candidates);
|
||||||
|
|
||||||
// Calculate entropy of the softmax probabilities
|
// Calculate entropy of the softmax probabilities
|
||||||
float entropy = 0.0f;
|
float entropy = 0.0f;
|
||||||
@ -398,13 +456,15 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
||||||
double max_l_double = candidates->data[0].logit;
|
const double max_l_double = candidates->data[0].logit;
|
||||||
|
|
||||||
double cum_sum_double = 0.0;
|
double cum_sum_double = 0.0;
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
double p = exp(candidates->data[i].logit - max_l_double);
|
double p = exp(candidates->data[i].logit - max_l_double);
|
||||||
candidates->data[i].p = p; // Store the scaled probability
|
candidates->data[i].p = p; // Store the scaled probability
|
||||||
cum_sum_double += p;
|
cum_sum_double += p;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
||||||
}
|
}
|
||||||
@ -416,44 +476,24 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||||||
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
|
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
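The entropy driving the dynamic temperature above is the usual Shannon entropy of the softmaxed candidates, normalized against the maximum possible value -logf(1.0f / candidates->size). A standalone sketch of the entropy term:

#include <cmath>
#include <vector>

static float entropy_sketch(const std::vector<float> & probs) {
    float entropy = 0.0f;
    for (float p : probs) {
        if (p > 0.0f) { entropy += -p * logf(p); }
    }
    return entropy;   // divide by -logf(1.0f / probs.size()) to normalize into [0, 1]
}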
void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
|
void llama_sampling_temp_impl(llama_token_data_array * candidates, float temp) {
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
candidates->data[i].logit /= temp;
|
candidates->data[i].logit /= temp;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sample_repetition_penalties_impl(
|
void llama_sampling_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) {
|
||||||
struct llama_sampling * smpl,
|
llama_grammar_apply_impl(grammar, candidates);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sampling_penalties_impl(
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
const llama_token * last_tokens,
|
const llama_token_cnt & token_count,
|
||||||
size_t penalty_last_n,
|
|
||||||
float penalty_repeat,
|
float penalty_repeat,
|
||||||
float penalty_freq,
|
float penalty_freq,
|
||||||
float penalty_present) {
|
float penalty_present) {
|
||||||
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
// Create a frequency map to count occurrences of each token in last_tokens
|
|
||||||
std::unordered_map<llama_token, int> token_count;
|
|
||||||
for (size_t i = 0; i < penalty_last_n; ++i) {
|
|
||||||
token_count[last_tokens[i]]++;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply frequency and presence penalties to the candidates
|
// Apply frequency and presence penalties to the candidates
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
const auto token_iter = token_count.find(candidates->data[i].id);
|
const auto token_iter = token_count.find(candidates->data[i].id);
|
||||||
@ -475,43 +515,10 @@ void llama_sample_repetition_penalties_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
candidates->sorted = false;
|
candidates->sorted = false;
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
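Note that the frequency map is now supplied by the caller instead of being rebuilt from last_tokens inside the function. One way a caller might build it from the prev ring buffer (a sketch; the surrounding sampling loop and includes are assumed):

llama_token_cnt token_count;

const int n = std::min((int) smpl.prev.size(), (int) smpl.params.penalty_last_n);
for (int i = 0; i < n; ++i) {
    token_count[smpl.prev.rat(i)]++;   // count occurrences inside the penalty window
}

llama_sampling_penalties_impl(candidates, token_count,
        smpl.params.penalty_repeat, smpl.params.penalty_freq, smpl.params.penalty_present);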
void llama_sample_apply_guidance_impl(
|
llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) {
|
||||||
struct llama_sampling * smpl,
|
llama_sampling_softmax_impl(candidates);
|
||||||
float * logits,
|
|
||||||
float * logits_guidance,
|
|
||||||
float scale) {
|
|
||||||
GGML_ASSERT(smpl);
|
|
||||||
|
|
||||||
const auto t_start_sample_us = ggml_time_us();
|
|
||||||
const auto n_vocab = smpl->n_vocab;
|
|
||||||
|
|
||||||
llama_log_softmax(logits, n_vocab);
|
|
||||||
llama_log_softmax(logits_guidance, n_vocab);
|
|
||||||
|
|
||||||
for (int i = 0; i < n_vocab; ++i) {
|
|
||||||
auto & l = logits[i];
|
|
||||||
const auto & g = logits_guidance[i];
|
|
||||||
|
|
||||||
l = scale * (l - g) + g;
|
|
||||||
}
|
|
||||||
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
|
||||||
GGML_ASSERT(smpl);
|
|
||||||
|
|
||||||
const int32_t n_vocab = float(smpl->n_vocab);
|
|
||||||
|
|
||||||
int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
|
||||||
|
|
||||||
// Estimate s_hat using the most probable m tokens
|
// Estimate s_hat using the most probable m tokens
|
||||||
float s_hat = 0.0;
|
float s_hat = 0.0;
|
||||||
@ -527,13 +534,11 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
|
|||||||
|
|
||||||
// Compute k from the estimated s_hat and target surprise value
|
// Compute k from the estimated s_hat and target surprise value
|
||||||
float epsilon_hat = s_hat - 1;
|
float epsilon_hat = s_hat - 1;
|
||||||
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
|
float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
|
||||||
|
|
||||||
// Sample the next word X using top-k sampling
|
// Sample the next word X using top-k sampling
|
||||||
llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
|
llama_sampling_top_k_impl(candidates, int(k), 1);
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
llama_token X = llama_sampling_sample_dist_impl(candidates, rng);
|
||||||
llama_token X = llama_sample_token_impl(smpl, candidates);
|
|
||||||
t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
// Compute error as the difference between observed surprise and target surprise value
|
// Compute error as the difference between observed surprise and target surprise value
|
||||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||||
@ -543,93 +548,88 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
|
|||||||
float e = observed_surprise - tau;
|
float e = observed_surprise - tau;
|
||||||
|
|
||||||
// Update mu using the learning rate and error
|
// Update mu using the learning rate and error
|
||||||
*mu = *mu - eta * e;
|
mu = mu - eta * e;
|
||||||
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
return X;
|
return X;
|
||||||
}
|
}
|
||||||
|
|
||||||
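Mirostat state is now a plain float owned by the caller and passed by reference, rather than a pointer plus timing fields on llama_sampling. A usage sketch (tau, eta, n_vocab and candidates come from the caller; the seed is illustrative):

float mu = 2.0f * tau;     // conventional start value: twice the target surprise (see the doc comments in llama-sampling.h below)
std::mt19937 rng(1234);

const llama_token tok =
    llama_sampling_sample_mirostat_impl(candidates, rng, tau, eta, /*m =*/ 100, n_vocab, mu);

// on return, mu has been nudged by -eta * (observed_surprise - tau) and is fed back on the next call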
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
|
llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) {
|
||||||
int64_t t_start_sample_us;
|
llama_sampling_softmax_impl(candidates);
|
||||||
t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
llama_sample_softmax_impl(smpl, candidates);
|
|
||||||
|
|
||||||
// Truncate the words with surprise values greater than mu
|
// Truncate the words with surprise values greater than mu
|
||||||
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||||
return -log2f(candidate.p) > *mu;
|
return -log2f(candidate.p) > mu;
|
||||||
}));
|
}));
|
||||||
|
|
||||||
if (candidates->size == 0) {
|
if (candidates->size == 0) {
|
||||||
candidates->size = 1;
|
candidates->size = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize the probabilities of the remaining words
|
// Normalize the probabilities of the remaining words
|
||||||
llama_sample_softmax_impl(smpl, candidates);
|
llama_sampling_softmax_impl(candidates);
|
||||||
|
|
||||||
// Sample the next word X from the remaining words
|
// Sample the next word X from the remaining words
|
||||||
llama_token X = llama_sample_token_impl(smpl, candidates);
|
llama_token X = llama_sampling_sample_dist_impl(candidates, rng);
|
||||||
t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
// Compute error as the difference between observed surprise and target surprise value
|
// Compute error as the difference between observed surprise and target surprise value
|
||||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||||
return candidate.id == X;
|
return candidate.id == X;
|
||||||
}));
|
}));
|
||||||
|
|
||||||
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||||
float e = observed_surprise - tau;
|
float e = observed_surprise - tau;
|
||||||
|
|
||||||
// Update mu using the learning rate and error
|
// Update mu using the learning rate and error
|
||||||
*mu = *mu - eta * e;
|
mu = mu - eta * e;
|
||||||
|
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
return X;
|
return X;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidates) {
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
// Find max element
|
// Find max element
|
||||||
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||||
return a.logit < b.logit;
|
return a.logit < b.logit;
|
||||||
});
|
});
|
||||||
|
|
||||||
llama_token result = max_iter->id;
|
llama_token result = max_iter->id;
|
||||||
if (smpl) {
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
smpl->n_sample++;
|
|
||||||
}
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
|
llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||||
GGML_ASSERT(smpl);
|
llama_sampling_softmax_impl(candidates);
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
|
||||||
|
|
||||||
std::vector<float> probs;
|
std::vector<float> probs;
|
||||||
probs.reserve(candidates->size);
|
probs.reserve(candidates->size);
|
||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
probs.push_back(candidates->data[i].p);
|
probs.push_back(candidates->data[i].p);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||||
int idx = dist(rng);
|
|
||||||
|
|
||||||
|
const int idx = dist(rng);
|
||||||
llama_token result = candidates->data[idx].id;
|
llama_token result = candidates->data[idx].id;
|
||||||
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
smpl->n_sample++;
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar) {
|
||||||
return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
|
smpl.prev.push_back(token);
|
||||||
|
|
||||||
|
if (apply_grammar && smpl.grammar) {
|
||||||
|
llama_grammar_accept_impl(*smpl.grammar, token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith) {
|
||||||
|
if (ith < 0 || ith >= (int) smpl.prev.size()) {
|
||||||
|
return LLAMA_TOKEN_NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
return smpl.prev.rat(ith);
|
||||||
|
}
|
||||||
|
|
||||||
|
int llama_sampling_n_prev_impl(const struct llama_sampling & smpl) {
|
||||||
|
return smpl.prev.size();
|
||||||
}
|
}
|
||||||
|
@ -1,56 +1,107 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "llama-impl.h"
|
#include "llama-grammar.h"
|
||||||
|
|
||||||
|
#include <random>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
struct llama_vocab;
|
||||||
|
struct llama_grammar;
|
||||||
|
|
||||||
|
using llama_token_cnt = std::unordered_map<llama_token, int>;
|
||||||
|
|
||||||
struct llama_sampling {
|
struct llama_sampling {
|
||||||
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
|
llama_sampling(const struct llama_vocab & vocab);
|
||||||
|
~llama_sampling();
|
||||||
|
|
||||||
|
llama_sampling_params params;
|
||||||
|
|
||||||
|
std::string grammar_str;
|
||||||
|
std::string grammar_root;
|
||||||
|
|
||||||
|
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||||
|
|
||||||
|
// state
|
||||||
|
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
|
||||||
int32_t n_vocab = 0;
|
const struct llama_vocab & vocab;
|
||||||
|
|
||||||
|
std::vector<llama_sampler_type> samplers;
|
||||||
|
|
||||||
|
ring_buffer<llama_token> prev;
|
||||||
|
|
||||||
|
struct llama_grammar * grammar = nullptr;
|
||||||
|
|
||||||
|
// mirostat sampler state
|
||||||
|
float mirostat_mu;
|
||||||
|
|
||||||
mutable int64_t t_sample_us = 0;
|
mutable int64_t t_sample_us = 0;
|
||||||
mutable int32_t n_sample = 0;
|
mutable int64_t t_grammar_us = 0;
|
||||||
|
mutable int64_t t_accept_us = 0;
|
||||||
|
|
||||||
void reset_timings() const {
|
mutable int32_t n_sample = 0;
|
||||||
t_sample_us = 0;
|
mutable int32_t n_grammar = 0;
|
||||||
n_sample = 0;
|
mutable int32_t n_accept = 0;
|
||||||
}
|
|
||||||
|
std::vector<llama_token_data> cur;
|
||||||
|
|
||||||
|
llama_token_data_array cur_p;
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
// internal API
|
// internal API
|
||||||
//
|
//
|
||||||
|
|
||||||
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
|
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params);
|
||||||
|
|
||||||
void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
void llama_sampling_free_impl(struct llama_sampling * sampling);
|
||||||
void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
|
|
||||||
void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
|
||||||
void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
|
||||||
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
|
|
||||||
void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
|
||||||
void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
|
|
||||||
void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
|
|
||||||
|
|
||||||
void llama_sample_repetition_penalties_impl(
|
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl);
|
||||||
struct llama_sampling * smpl,
|
|
||||||
|
void llama_sampling_reset_impl(struct llama_sampling & smpl);
|
||||||
|
|
||||||
|
// TODO: move the API below as member functions of llama_sampling
|
||||||
|
void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed);
|
||||||
|
void llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root);
|
||||||
|
void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
|
||||||
|
|
||||||
|
void llama_sampling_softmax_impl (struct llama_token_data_array * candidates);
|
||||||
|
void llama_sampling_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep);
|
||||||
|
void llama_sampling_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep);
|
||||||
|
void llama_sampling_min_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep);
|
||||||
|
void llama_sampling_tail_free_impl(struct llama_token_data_array * candidates, float z, size_t min_keep);
|
||||||
|
void llama_sampling_typical_impl (struct llama_token_data_array * candidates, float p, size_t min_keep);
|
||||||
|
void llama_sampling_entropy_impl (struct llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
|
||||||
|
void llama_sampling_temp_impl (struct llama_token_data_array * candidates, float temp);
|
||||||
|
void llama_sampling_grammar_impl (struct llama_token_data_array * candidates, const struct llama_grammar & grammar);
|
||||||
|
|
||||||
|
void llama_sampling_penalties_impl(
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
const llama_token * last_tokens,
|
const llama_token_cnt & token_count,
|
||||||
size_t penalty_last_n,
|
|
||||||
float penalty_repeat,
|
float penalty_repeat,
|
||||||
float penalty_freq,
|
float penalty_freq,
|
||||||
float penalty_present);
|
float penalty_present);
|
||||||
|
|
||||||
void llama_sample_apply_guidance_impl(
|
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
struct llama_sampling * smpl,
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||||
float * logits,
|
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||||
float * logits_guidance,
|
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||||
float scale);
|
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||||
|
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||||
|
llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu);
|
||||||
|
|
||||||
llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
|
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||||
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||||
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
|
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||||
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||||
|
llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);
|
||||||
|
|
||||||
|
llama_token llama_sampling_sample_greedy_impl(struct llama_token_data_array * candidates);
|
||||||
|
llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng);
|
||||||
|
|
||||||
|
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar);
|
||||||
|
|
||||||
|
llama_token llama_sampling_prev_impl (const struct llama_sampling & smpl, int ith);
|
||||||
|
int llama_sampling_n_prev_impl(const struct llama_sampling & smpl);
|
||||||
|
@ -18,6 +18,8 @@ struct llama_vocab {
|
|||||||
tattr attr;
|
tattr attr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
|
||||||
|
|
||||||
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
||||||
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
|
|
||||||
@ -62,8 +64,6 @@ struct llama_vocab {
|
|||||||
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// internal API
|
// internal API
|
||||||
//
|
//
|
||||||
@ -76,6 +76,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
|
|||||||
bool add_special,
|
bool add_special,
|
||||||
bool parse_special = false);
|
bool parse_special = false);
|
||||||
|
|
||||||
|
// TODO: move the API below as member functions of llama_vocab
|
||||||
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
|
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
|
||||||
|
|
||||||
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
|
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
|
||||||
|
594 src/llama.cpp
@ -1,6 +1,5 @@
|
|||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
#include "llama-vocab.h"
|
#include "llama-vocab.h"
|
||||||
#include "llama-grammar.h"
|
|
||||||
#include "llama-sampling.h"
|
#include "llama-sampling.h"
|
||||||
|
|
||||||
#include "unicode.h"
|
#include "unicode.h"
|
||||||
@ -148,6 +147,19 @@ static void zeros(std::ofstream & file, size_t n) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct time_meas {
|
||||||
|
time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {}
|
||||||
|
|
||||||
|
~time_meas() {
|
||||||
|
t_acc += ggml_time_us() - t_start_us;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t t_start_us;
|
||||||
|
|
||||||
|
int64_t & t_acc;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
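time_meas replaces the scattered t_start_sample_us bookkeeping removed elsewhere in this commit: construct it on the stack around the timed region and its destructor adds the elapsed microseconds to the referenced accumulator. For example (the timed block is illustrative):

{
    time_meas tm(ctx->t_eval_us);   // clock starts here
    // ... run the decode/eval work ...
}                                   // ~time_meas() adds ggml_time_us() - t_start_us to ctx->t_eval_us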
LLAMA_ATTRIBUTE_FORMAT(1, 2)
|
LLAMA_ATTRIBUTE_FORMAT(1, 2)
|
||||||
static std::string format(const char * fmt, ...) {
|
static std::string format(const char * fmt, ...) {
|
||||||
va_list ap;
|
va_list ap;
|
||||||
@ -3179,7 +3191,6 @@ struct llama_sbatch {
|
|||||||
struct llama_context {
|
struct llama_context {
|
||||||
llama_context(const llama_model & model)
|
llama_context(const llama_model & model)
|
||||||
: model(model)
|
: model(model)
|
||||||
, sampling(llama_n_vocab(&model))
|
|
||||||
, t_start_us(model.t_start_us)
|
, t_start_us(model.t_start_us)
|
||||||
, t_load_us(model.t_load_us) {}
|
, t_load_us(model.t_load_us) {}
|
||||||
|
|
||||||
@ -3196,7 +3207,6 @@ struct llama_context {
|
|||||||
const struct llama_model & model;
|
const struct llama_model & model;
|
||||||
|
|
||||||
struct llama_cparams cparams;
|
struct llama_cparams cparams;
|
||||||
struct llama_sampling sampling;
|
|
||||||
struct llama_sbatch sbatch;
|
struct llama_sbatch sbatch;
|
||||||
struct llama_kv_cache kv_self;
|
struct llama_kv_cache kv_self;
|
||||||
struct llama_control_vector cvec;
|
struct llama_control_vector cvec;
|
||||||
@ -3217,16 +3227,16 @@ struct llama_context {
|
|||||||
|
|
||||||
bool has_evaluated_once = false;
|
bool has_evaluated_once = false;
|
||||||
|
|
||||||
int64_t t_start_us;
|
mutable int64_t t_start_us;
|
||||||
int64_t t_load_us;
|
mutable int64_t t_load_us;
|
||||||
int64_t t_p_eval_us = 0;
|
mutable int64_t t_p_eval_us = 0;
|
||||||
int64_t t_eval_us = 0;
|
mutable int64_t t_eval_us = 0;
|
||||||
|
|
||||||
int64_t t_compute_start_us = 0;
|
mutable int64_t t_compute_start_us = 0;
|
||||||
int64_t n_queued_tokens = 0;
|
mutable int64_t n_queued_tokens = 0;
|
||||||
|
|
||||||
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
||||||
int32_t n_eval = 0; // number of eval calls
|
mutable int32_t n_eval = 0; // number of eval calls
|
||||||
|
|
||||||
// host buffer for the model output (logits and embeddings)
|
// host buffer for the model output (logits and embeddings)
|
||||||
ggml_backend_buffer_t buf_output = nullptr;
|
ggml_backend_buffer_t buf_output = nullptr;
|
||||||
@ -6239,6 +6249,7 @@ static void llm_load_vocab(
|
|||||||
|
|
||||||
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
|
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
|
||||||
|
|
||||||
|
vocab.n_vocab = n_vocab;
|
||||||
vocab.id_to_token.resize(n_vocab);
|
vocab.id_to_token.resize(n_vocab);
|
||||||
|
|
||||||
for (uint32_t i = 0; i < n_vocab; i++) {
|
for (uint32_t i = 0; i < n_vocab; i++) {
|
||||||
@ -17863,7 +17874,6 @@ struct llama_model_params llama_model_default_params() {
|
|||||||
|
|
||||||
struct llama_context_params llama_context_default_params() {
|
struct llama_context_params llama_context_default_params() {
|
||||||
struct llama_context_params result = {
|
struct llama_context_params result = {
|
||||||
/*.seed =*/ LLAMA_DEFAULT_SEED,
|
|
||||||
/*.n_ctx =*/ 512,
|
/*.n_ctx =*/ 512,
|
||||||
/*.n_batch =*/ 2048,
|
/*.n_batch =*/ 2048,
|
||||||
/*.n_ubatch =*/ 512,
|
/*.n_ubatch =*/ 512,
|
||||||
@ -17896,6 +17906,36 @@ struct llama_context_params llama_context_default_params() {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_sampling_params llama_sampling_default_params() {
|
||||||
|
struct llama_sampling_params result = {
|
||||||
|
/*.seed =*/ LLAMA_DEFAULT_SEED,
|
||||||
|
/*.n_prev =*/ 64,
|
||||||
|
/*.n_probs =*/ 0,
|
||||||
|
/*.min_keep =*/ 0,
|
||||||
|
/*.top_k =*/ 40,
|
||||||
|
/*.top_p =*/ 0.95f,
|
||||||
|
/*.min_p =*/ 0.05f,
|
||||||
|
/*.tfs_z =*/ 1.00f,
|
||||||
|
/*.typ_p =*/ 1.00f,
|
||||||
|
/*.temp =*/ 0.80f,
|
||||||
|
/*.dynatemp_range =*/ 0.00f,
|
||||||
|
/*.dynatemp_exponent =*/ 1.00f,
|
||||||
|
/*.penalty_last_n =*/ 64,
|
||||||
|
/*.penalty_repeat =*/ 1.00f,
|
||||||
|
/*.penalty_freq =*/ 0.00f,
|
||||||
|
/*.penalty_present =*/ 0.00f,
|
||||||
|
/*.mirostat =*/ 0,
|
||||||
|
/*.mirostat_tau =*/ 5.00f,
|
||||||
|
/*.mirostat_eta =*/ 0.10f,
|
||||||
|
/*.n_samplers =*/ 3,
|
||||||
|
/*.samplers =*/ { LLAMA_SAMPLER_TYPE_TEMPERATURE, LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TOP_P, },
|
||||||
|
/*.penalize_nl =*/ false,
|
||||||
|
/*.ignore_eos =*/ false,
|
||||||
|
};
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
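Callers pick up the new defaults and override individual fields; note that the RNG seed now lives in llama_sampling_params rather than llama_context_params (it was removed from the context defaults above). A sketch:

llama_sampling_params sparams = llama_sampling_default_params();

sparams.seed  = 1234;     // sampling owns the RNG seed now
sparams.top_k = 50;
sparams.temp  = 0.7f;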
struct llama_model_quantize_params llama_model_quantize_default_params() {
|
struct llama_model_quantize_params llama_model_quantize_default_params() {
|
||||||
struct llama_model_quantize_params result = {
|
struct llama_model_quantize_params result = {
|
||||||
/*.nthread =*/ 0,
|
/*.nthread =*/ 0,
|
||||||
@ -18149,10 +18189,6 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
|
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
|
||||||
params.seed = time(NULL);
|
|
||||||
}
|
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||||
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
||||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||||
@ -18163,8 +18199,8 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
ctx->abort_callback = params.abort_callback;
|
ctx->abort_callback = params.abort_callback;
|
||||||
ctx->abort_callback_data = params.abort_callback_data;
|
ctx->abort_callback_data = params.abort_callback_data;
|
||||||
|
|
||||||
ctx->sampling.rng = std::mt19937(params.seed);
|
|
||||||
ctx->logits_all = params.logits_all;
|
ctx->logits_all = params.logits_all;
|
||||||
|
|
||||||
// build worst-case graph for encoder if a model contains encoder
|
// build worst-case graph for encoder if a model contains encoder
|
||||||
ctx->is_encoding = llama_model_has_encoder(model);
|
ctx->is_encoding = llama_model_has_encoder(model);
|
||||||
|
|
||||||
@ -18443,14 +18479,6 @@ void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
 
-const struct llama_model * llama_get_model(const struct llama_context * ctx) {
-    return &ctx->model;
-}
-
-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
-    return &ctx->model.vocab;
-}
-
 uint32_t llama_n_ctx(const struct llama_context * ctx) {
     return ctx->cparams.n_ctx;
 }
@ -18471,6 +18499,30 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }
 
+int32_t llama_n_vocab(const struct llama_model * model) {
+    return model->hparams.n_vocab;
+}
+
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_n_embd(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_n_layer(const struct llama_model * model) {
+    return model->hparams.n_layer;
+}
+
+const struct llama_model * llama_get_model(const struct llama_context * ctx) {
+    return &ctx->model;
+}
+
+enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
+    return ctx->cparams.pooling_type;
+}
+
 enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
@ -18534,26 +18586,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     return LLAMA_ROPE_TYPE_NONE;
 }
 
-enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
-    return ctx->cparams.pooling_type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
-
-int32_t llama_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
-
-int32_t llama_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
-
 float llama_rope_freq_scale_train(const struct llama_model * model) {
     return model->hparams.rope_freq_scale_train;
 }
@ -18970,14 +19002,14 @@ struct llama_data_write {
         // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
 
-    void write_rng(const std::mt19937 & rng) {
-        std::ostringstream rng_ss;
-        rng_ss << rng;
+    //void write_rng(const std::mt19937 & rng) {
+    //    std::ostringstream rng_ss;
+    //    rng_ss << rng;
 
-        const std::string & rng_str = rng_ss.str();
+    //    const std::string & rng_str = rng_ss.str();
 
-        write_string(rng_str);
-    }
+    //    write_string(rng_str);
+    //}
 
     void write_output_ids(struct llama_context * ctx) {
         llama_output_reorder(ctx);
@ -19197,17 +19229,17 @@ struct llama_data_read {
         // TODO: add more info which needs to be identical but which is not verified otherwise
     }
 
-    void read_rng(std::mt19937 & rng) {
-        std::string rng_str;
-        read_string(rng_str);
+    //void read_rng(std::mt19937 & rng) {
+    //    std::string rng_str;
+    //    read_string(rng_str);
 
-        std::istringstream rng_ss(rng_str);
-        rng_ss >> rng;
+    //    std::istringstream rng_ss(rng_str);
+    //    rng_ss >> rng;
 
-        if (rng_ss.fail()) {
-            throw std::runtime_error("failed to load RNG state");
-        }
-    }
+    //    if (rng_ss.fail()) {
+    //        throw std::runtime_error("failed to load RNG state");
+    //    }
+    //}
 
     void read_output_ids(struct llama_context * ctx) {
         std::vector<int32_t> output_pos;
@ -19637,8 +19669,6 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
 
     data_ctx.write_model_info(ctx);
 
-    data_ctx.write_rng(ctx->sampling.rng);
-
     // copy outputs
     data_ctx.write_output_ids(ctx);
     data_ctx.write_logits(ctx);
@ -19676,9 +19706,6 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
 
     data_ctx.read_model_info(ctx);
 
-    // set rng
-    data_ctx.read_rng(ctx->sampling.rng);
-
     // set outputs
     data_ctx.read_output_ids(ctx);
     data_ctx.read_logits(ctx);
@ -20081,8 +20108,9 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
         GGML_ABORT("fatal error");
-#endif
+#else
         return nullptr;
+#endif
     }
 }
 
@ -20130,8 +20158,9 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
         GGML_ABORT("fatal error");
-#endif
+#else
         return nullptr;
+#endif
     }
 }
 
@ -20564,125 +20593,350 @@ int32_t llama_chat_apply_template(
     return res;
 }
 
-//
-// grammar
-//
-
-struct llama_grammar * llama_grammar_init(
-        const llama_grammar_element ** rules,
-        size_t n_rules,
-        size_t start_rule_index) {
-    return llama_grammar_init_impl(rules, n_rules, start_rule_index);
-}
-
-void llama_grammar_free(struct llama_grammar * grammar) {
-    llama_grammar_free_impl(grammar);
-}
-
-struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    return llama_grammar_copy_impl(grammar);
-}
-
-void llama_grammar_sample(
-        const struct llama_grammar * grammar,
-        const struct llama_context * ctx,
-        llama_token_data_array * candidates) {
-    llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
-}
-
-void llama_sample_grammar(
-        struct llama_context * ctx,
-        llama_token_data_array * candidates,
-        const struct llama_grammar * grammar) {
-    llama_grammar_sample(grammar, ctx, candidates);
-}
-
-void llama_grammar_accept_token(
-        struct llama_grammar * grammar,
-        struct llama_context * ctx,
-        llama_token token) {
-    llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
-}
-
 //
 // sampling
 //
 
-void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
-    llama_set_rng_seed_impl(&ctx->sampling, seed);
+struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) {
+    return llama_sampling_init_impl(model->vocab, params);
 }
 
-void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
-    llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
+void llama_sampling_free(struct llama_sampling * smpl) {
+    if (smpl == nullptr) {
+        return;
+    }
+
+    llama_sampling_free_impl(smpl);
 }
 
-void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
+struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) {
+    return llama_sampling_cp_impl(*smpl);
 }
 
-void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+void llama_sampling_reset(struct llama_sampling * smpl) {
+    llama_sampling_reset_impl(*smpl);
 }
 
-void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
+    llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
 }
 
-void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
-    llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
+void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
+    llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias);
 }
 
-void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) {
+    const int n_vocab = smpl->vocab.n_vocab;
+
+    smpl->cur.resize(n_vocab);
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    for (const auto & lb : smpl->logit_bias) {
+        smpl->cur[lb.token].logit += lb.bias;
+    }
+
+    if (smpl->params.ignore_eos) {
+        smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY;
+    }
+
+    smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false };
+
+    // apply penalties
+    {
+        const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit;
+
+        llama_sampling_penalties(smpl, &smpl->cur_p);
+
+        if (!smpl->params.penalize_nl) {
+            for (size_t idx = 0; idx < smpl->cur_p.size; idx++) {
+                if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) {
+                    smpl->cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
 }
 
-void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
-    llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
+llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) {
+    return &smpl->cur_p;
 }
 
-void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
-    llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
+void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    llama_sampling_softmax_impl(candidates);
 }
 
-void llama_sample_repetition_penalties(
-        struct llama_context * ctx,
-        llama_token_data_array * candidates,
-        const llama_token * last_tokens,
-        size_t penalty_last_n,
-        float penalty_repeat,
-        float penalty_freq,
-        float penalty_present) {
-    llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep);
 }
 
-void llama_sample_apply_guidance(
-        struct llama_context * ctx,
-        float * logits,
-        float * logits_guidance,
-        float scale) {
-    llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
+void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep);
 }
 
-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
+void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep);
 }
 
-llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
+void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep);
 }
 
-llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
+void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep);
 }
 
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
+void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    if (smpl->params.dynatemp_range > 0) {
+        const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range);
+        const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range);
+
+        llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent);
+    } else {
+        llama_sampling_temp_impl(candidates, smpl->params.temp);
+    }
 }
 
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
+void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_grammar_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    if (smpl->grammar) {
+        llama_sampling_grammar_impl(candidates, *smpl->grammar);
+
+        smpl->n_grammar++;
+    }
 }
 
+void llama_sampling_penalties(
+        struct llama_sampling * smpl,
+        llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    const size_t penalty_last_n = std::min<size_t>(smpl->params.penalty_last_n, smpl->prev.size());
+
+    const float penalty_repeat = smpl->params.penalty_repeat;
+    const float penalty_freq = smpl->params.penalty_freq;
+    const float penalty_present = smpl->params.penalty_present;
+
+    if ((penalty_last_n == 0) ||
+        (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
+        return;
+    }
+
+    // Create a frequency map to count occurrences of each token in last_tokens
+    // TODO: move to sampling state and avoid reallocation
+    llama_token_cnt token_count;
+    for (size_t i = 0; i < penalty_last_n; ++i) {
+        token_count[smpl->prev.rat(i)]++;
+    }
+
+    llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present);
+}
+
+llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    const auto type = smpl->params.mirostat;
+
+    llama_token res;
+
+    if (type == 1) {
+        res = llama_sampling_sample_mirostat_impl(candidates,
+                smpl->rng,
+                smpl->params.mirostat_tau,
+                smpl->params.mirostat_eta,
+                100,
+                smpl->vocab.n_vocab,
+                smpl->mirostat_mu);
+    } else if (type == 2) {
+        res = llama_sampling_sample_mirostat_v2_impl(candidates,
+                smpl->rng,
+                smpl->params.mirostat_tau,
+                smpl->params.mirostat_eta,
+                smpl->mirostat_mu);
+    } else {
+        GGML_ABORT("invalid mirostat type: %d", type);
+    }
+
+    smpl->n_sample++;
+
+    return res;
+}
+
+llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    auto res = llama_sampling_sample_greedy_impl(candidates);
+
+    smpl->n_sample++;
+
+    return res;
+}
+
+llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);
+
+    smpl->n_sample++;
+
+    return res;
+}
+
+llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    const auto & params = smpl->params;
+
+    const float temp = params.temp;
+    const int mirostat = params.mirostat;
+
+    auto & cur_p = candidates;
+
+    llama_token res = 0;
+
+    if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
+        // greedy sampling, with probs
+        llama_sampling_softmax_impl(cur_p);
+        res = cur_p->data[0].id;
+    } else if (temp == 0.0f) {
+        // greedy sampling, no probs
+        res = llama_sampling_sample_greedy(smpl, cur_p);
+    } else {
+        if (mirostat != 0) {
+            llama_sampling_temp(smpl, cur_p);
+            res = llama_sampling_sample_mirostat(smpl, cur_p);
+        } else {
+            for (const auto & sampler : smpl->samplers) {
+                switch (sampler) {
+                    case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break;
+                    default : break;
+                }
+            }
+
+            res = llama_sampling_sample_dist(smpl, cur_p);
+
+            //{
+            //    const int n_top = 10;
+            //    LOG("top %d candidates:\n", n_top);
+
+            //    for (int i = 0; i < n_top; i++) {
+            //        const llama_token id = cur_p.data[i].id;
+            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
+            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
+            //    }
+            //}
+
+            //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
+        }
+    }
+
+    smpl->n_sample++;
+
+    return res;
+}
+
+void llama_sampling_accept(
+        struct llama_sampling * smpl,
+        llama_token token,
+        bool apply_grammar) {
+    time_meas tm(smpl->t_accept_us);
+
+    llama_sampling_accept_impl(*smpl, token, apply_grammar);
+
+    smpl->n_accept++;
+}
+
+int llama_sampling_n_prev(const struct llama_sampling * smpl) {
+    return llama_sampling_n_prev_impl(*smpl);
+}
+
+llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) {
+    return llama_sampling_prev_impl(*smpl, ith);
+}
+
+llama_token llama_sampling_last(const struct llama_sampling * smpl) {
+    return llama_sampling_prev_impl(*smpl, 0);
+}
+
+//
+// model split
+//
+
 int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
     if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
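Not part of the diff: a minimal sketch of the per-token loop implied by the new API above (llama_sampling_set_logits, then llama_sampling_sample, then llama_sampling_accept). The surrounding model/context setup is assumed to exist elsewhere.

// illustrative sketch only; ctx and smpl are assumed to be created and freed elsewhere
#include "llama.h"

static llama_token sample_next(struct llama_context * ctx, struct llama_sampling * smpl, int32_t idx) {
    // feed the logits of output idx into the sampling state (applies logit biases, penalties, etc.)
    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));

    // run the configured sampler chain (or greedy/mirostat, depending on the params);
    // passing nullptr makes it operate on the internal candidate array
    const llama_token id = llama_sampling_sample(smpl, nullptr);

    // record the token in the sampler history and advance the grammar state, if any
    llama_sampling_accept(smpl, id, /*apply_grammar=*/ true);

    return id;
}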
@ -20707,30 +20961,32 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
     return 0;
 }
 
-struct llama_timings llama_get_timings(struct llama_context * ctx) {
-    struct llama_timings result = {
+void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl) {
+    const llama_timings timings = {
         /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
         /*.t_end_ms =*/ 1.00 * ggml_time_ms(),
         /*.t_load_ms =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
+        /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0),
+        /*.t_grammar_ms =*/ 1e-3 * (smpl ? smpl->t_grammar_us : 0.0),
+        /*.t_accept_ms =*/ 1e-3 * (smpl ? smpl->t_accept_us : 0.0),
         /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
         /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
 
-        /*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
+        /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0),
+        /*.n_grammar =*/ std::max(0, smpl ? smpl->n_grammar : 0),
+        /*.n_accept =*/ std::max(0, smpl ? smpl->n_accept : 0),
         /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
         /*.n_eval =*/ std::max(1, ctx->n_eval),
     };
 
-    return result;
-}
-
-void llama_print_timings(struct llama_context * ctx) {
-    const llama_timings timings = llama_get_timings(ctx);
-
     LLAMA_LOG_INFO("\n");
     LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms);
-    LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample);
+    LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling);
+    LLAMA_LOG_INFO("%s: grammar time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, timings.t_grammar_ms, timings.n_grammar, timings.t_grammar_ms / timings.n_grammar, 1e3 / timings.t_grammar_ms * timings.n_grammar);
+    //LLAMA_LOG_INFO("%s: accept time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+    //        __func__, timings.t_accept_ms, timings.n_accept, timings.t_accept_ms / timings.n_accept, 1e3 / timings.t_accept_ms * timings.n_accept);
     LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
     LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
@ -20738,12 +20994,16 @@ void llama_print_timings(struct llama_context * ctx) {
     LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
 }
 
-void llama_reset_timings(struct llama_context * ctx) {
+void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl) {
     ctx->t_start_us = ggml_time_us();
     ctx->t_eval_us = ctx->n_eval = 0;
     ctx->t_p_eval_us = ctx->n_p_eval = 0;
 
-    ctx->sampling.reset_timings();
+    if (smpl) {
+        smpl->t_sample_us = smpl->n_sample = 0;
+        smpl->t_grammar_us = smpl->n_grammar = 0;
+        smpl->t_accept_us = smpl->n_accept = 0;
+    }
 }
 
 const char * llama_print_system_info(void) {
@ -20785,21 +21045,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
             1.0e-3 * ctx->t_eval_us / ctx->n_eval);
     fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n",
             1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
-    fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n",
-            1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
     fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval);
     fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
-    fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample);
     fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us);
     fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us);
     fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
-    fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
     fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n",
             1.0e6 * ctx->n_eval / ctx->t_eval_us);
     fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n",
             1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
-    fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n",
-            1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
 }
 
 // For internal test use
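Not part of the diff: a short sketch of the updated timing helpers, which now take the sampling state explicitly (it may be NULL when no sampler is used).

// illustrative sketch only; assumes ctx and smpl were created elsewhere
static void report_and_reset(struct llama_context * ctx, struct llama_sampling * smpl) {
    llama_print_timings(ctx, smpl);
    llama_reset_timings(ctx, smpl);
}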
@ -2,33 +2,18 @@
 #undef NDEBUG
 #endif
 
-#define LLAMA_API_INTERNAL
-
-#include "ggml.h"
-#include "llama.h"
-#include "grammar-parser.h"
-#include "json-schema-to-grammar.h"
 #include "unicode.h"
+#include "llama-grammar.h"
+#include "json-schema-to-grammar.h"
 
 #include <cassert>
 #include <string>
 #include <vector>
 
 using json = nlohmann::ordered_json;
 
-static llama_grammar* build_grammar(const std::string & grammar_str) {
-    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
-    // Ensure we parsed correctly
-    assert(!parsed_grammar.rules.empty());
-
-    // Ensure we have a root node
-    assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
-
-    std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
-    llama_grammar* grammar = llama_grammar_init(
-        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-
-    return grammar;
+static llama_grammar * build_grammar(const std::string & grammar_str) {
+    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
 }
 
 static bool test_build_grammar_fails(const std::string & grammar_str) {
@ -45,17 +30,15 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
 }
 
 static bool match_string(const std::string & input, llama_grammar * grammar) {
-    auto decoded = decode_utf8(input, {});
-
-    const auto & code_points = decoded.first;
+    const auto cpts = unicode_cpts_from_utf8(input);
 
     const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
     llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
 
-    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+    for (const auto & cpt : cpts) {
         const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
 
-        llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+        cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);
 
         if (cur_stacks.empty()) {
             // no stacks means that the grammar failed to match at this point
@ -77,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
     fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
     fflush(stderr);
 
-    auto grammar = build_grammar(grammar_str);
+    auto * grammar = build_grammar(grammar_str);
 
     // Save the original grammar stacks so that we can reset after every new string we want to test
     const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
@ -143,7 +126,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
     }
 
     // Clean up allocated memory
-    llama_grammar_free(grammar);
+    llama_grammar_free_impl(grammar);
 }
 static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
     test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
@ -683,7 +666,8 @@ static void test_failure_missing_root() {
         term ::= number
         number ::= [0-9]+)""";
 
-    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+    llama_grammar_parser parsed_grammar;
+    parsed_grammar.parse(grammar_str.c_str());
 
     // Ensure we parsed correctly
     assert(!parsed_grammar.rules.empty());
@ -705,7 +689,8 @@ static void test_failure_missing_reference() {
 
     fprintf(stderr, " Expected error: ");
 
-    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+    llama_grammar_parser parsed_grammar;
+    parsed_grammar.parse(grammar_str.c_str());
 
     // Ensure we did NOT parsed correctly
     assert(parsed_grammar.rules.empty());
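Not part of the diff: the tests above replace grammar_parser::parse with the llama_grammar_parser object. A minimal sketch of that pattern, assuming llama-grammar.h is available, looks like this.

// illustrative sketch only; mirrors the updated test pattern
#include "llama-grammar.h"

static bool grammar_parses(const char * grammar_str) {
    llama_grammar_parser parsed;
    parsed.parse(grammar_str);

    // an empty rule set indicates that parsing failed
    return !parsed.rules.empty();
}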
@ -3,7 +3,7 @@
 #endif
 
 #include "llama.h"
-#include "grammar-parser.h"
+#include "llama-grammar.h"
 
 #include <cassert>
 
@ -22,7 +22,8 @@ static const char * type_str(llama_gretype type) {
 
 static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
     uint32_t index = 0;
-    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes);
+    llama_grammar_parser parsed_grammar;
+    parsed_grammar.parse(grammar_bytes);
 
     std::map<uint32_t, std::string> symbol_names;
     for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
@ -129,9 +130,10 @@ static void verify_parsing(const char *grammar_bytes, const std::vector<std::pai
     }
 }
 
-static void verify_failure(const char *grammar_bytes) {
+static void verify_failure(const char * grammar_bytes) {
     fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);
-    auto result = grammar_parser::parse(grammar_bytes);
+    llama_grammar_parser result;
+    result.parse(grammar_bytes);
     assert(result.rules.empty() && "should have failed");
 }
 
@ -2,14 +2,15 @@
 #undef NDEBUG
 #endif
 
+#include "json-schema-to-grammar.h"
+
+#include "llama-grammar.h"
+
 #include <cassert>
 #include <fstream>
 #include <sstream>
 #include <regex>
 
-#include "json-schema-to-grammar.h"
-#include "grammar-parser.h"
-
 static std::string trim(const std::string & source) {
     std::string s(source);
     s.erase(0,s.find_first_not_of(" \n\r\t"));
@ -40,7 +41,8 @@ struct TestCase {
     }
     void verify_expectation_parseable() const {
         try {
-            auto state = grammar_parser::parse(expected_grammar.c_str());
+            llama_grammar_parser state;
+            state.parse(expected_grammar.c_str());
             if (state.symbol_ids.find("root") == state.symbol_ids.end()) {
                 throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar);
             }
@ -2,16 +2,15 @@
 #undef NDEBUG
 #endif
 
-#define LLAMA_API_INTERNAL
 #include "llama.h"
-#include "grammar-parser.h"
+#include "llama-grammar.h"
 
 #include <cassert>
 #include <stdexcept>
 
 int main()
 {
-    grammar_parser::parse_state parsed_grammar;
+    llama_grammar_parser parsed_grammar;
 
     std::vector<std::pair<std::string, uint32_t>> expected = {
         {"expr", 2},
@ -117,7 +116,7 @@ int main()
     llama_grammar * grammar = NULL;
     std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
 
-    grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     if (grammar == nullptr)
     {
         throw std::runtime_error("Failed to initialize llama_grammar");
@ -174,13 +173,13 @@ int main()
     }};
 
     auto index = 0;
-    for (auto stack : llama_grammar_get_stacks(grammar))
+    for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
     {
         // compare stack to expected_stack
         for (uint32_t i = 0; i < stack.size(); i++)
         {
-            auto element = stack[i];
-            auto expected_element = expected_stacks[index][i];
+            const llama_grammar_element * element = stack[i];
+            const llama_grammar_element & expected_element = expected_stacks[index][i];
 
             // pretty print error message before asserting
             if (expected_element.type != element->type || expected_element.value != element->value)
@ -403,6 +402,8 @@ int main()
         delete[] candidate.code_points;
         candidate.code_points = nullptr;
     }
-    llama_grammar_free(grammar);
+
+    llama_grammar_free_impl(grammar);
 
     return 0;
 }
@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "llama.h"
+#include "llama-sampling.h"
 
 #ifdef NDEBUG
 #undef NDEBUG
@ -20,6 +21,7 @@ static void dump(const llama_token_data_array * candidates) {
 
 static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
     const size_t n_vocab = probs.size();
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -28,9 +30,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
     }
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
+    llama_sampling_softmax_impl(&candidates_p);
     DUMP(&candidates_p);
-    llama_sample_top_k(nullptr, &candidates_p, k, 1);
+    llama_sampling_top_k_impl(&candidates_p, k, 1);
     DUMP(&candidates_p);
 
     GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -41,6 +43,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
 
 static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -49,9 +52,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
     }
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
+    llama_sampling_softmax_impl(&candidates_p);
     DUMP(&candidates_p);
-    llama_sample_top_p(nullptr, &candidates_p, p, 1);
+    llama_sampling_top_p_impl(&candidates_p, p, 1);
     DUMP(&candidates_p);
 
     GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -62,6 +65,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
 
 static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
     const size_t n_vocab = probs.size();
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -71,7 +75,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     DUMP(&candidates_p);
-    llama_sample_tail_free(nullptr, &candidates_p, z, 1);
+    llama_sampling_tail_free_impl(&candidates_p, z, 1);
     DUMP(&candidates_p);
 
     GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -82,6 +86,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
 
 static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -91,9 +96,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     DUMP(&candidates_p);
-    llama_sample_min_p(nullptr, &candidates_p, p, 1);
+    llama_sampling_min_p_impl(&candidates_p, p, 1);
     DUMP(&candidates_p);
-    llama_sample_softmax(nullptr, &candidates_p);
+    llama_sampling_softmax_impl(&candidates_p);
 
     GGML_ASSERT(candidates_p.size == expected_probs.size());
     for (size_t i = 0; i < candidates_p.size; i++) {
@ -103,6 +108,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
 
 static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -112,7 +118,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     DUMP(&candidates_p);
-    llama_sample_typical(nullptr, &candidates_p, p, 1);
+    llama_sampling_typical_impl(&candidates_p, p, 1);
     DUMP(&candidates_p);
 
     GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -121,13 +127,14 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
     }
 }
 
-static void test_repetition_penalties(
+static void test_penalties(
     const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
     const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
 ) {
     GGML_ASSERT(probs.size() == expected_probs.size());
 
     const size_t n_vocab = probs.size();
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -135,11 +142,16 @@ static void test_repetition_penalties(
         candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
+    llama_token_cnt token_count;
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        token_count[last_tokens[i]]++;
+    }
+
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
+    llama_sampling_softmax_impl(&candidates_p);
     DUMP(&candidates_p);
-    llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
-    llama_sample_softmax(nullptr, &candidates_p);
+    llama_sampling_penalties_impl(&candidates_p, token_count, repeat_penalty, alpha_frequency, alpha_presence);
+    llama_sampling_softmax_impl(&candidates_p);
     DUMP(&candidates_p);
 
     GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -148,8 +160,7 @@ static void test_repetition_penalties(
     }
 }
 
-static void test_sampler_queue(
-    const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
+static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
 ) {
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@ -165,16 +176,16 @@ static void test_sampler_queue(
 
     for (auto s : samplers_sequence) {
         switch (s){
-            case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
+            case 'k': llama_sampling_top_k_impl(&candidates_p, top_k, 1); break;
             case 'f': GGML_ABORT("tail_free test not implemented");
             case 'y': GGML_ABORT("typical test not implemented");
-            case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
-            case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
+            case 'p': llama_sampling_top_p_impl(&candidates_p, top_p, 1); break;
+            case 'm': llama_sampling_min_p_impl(&candidates_p, min_p, 1); break;
            case 't': GGML_ABORT("temperature test not implemented");
            default : GGML_ABORT("Unknown sampler");
         }
 
-        llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
+        llama_sampling_softmax_impl(&candidates_p); // make sure tokens are sorted for tests
 
         const int size = candidates_p.size;
 
@ -259,13 +270,13 @@ int main(void) {
     test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
     test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
 
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
 
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
 
     test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
     test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);