mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 05:42:22 +01:00
llama : minor sampling refactor (2) (#9386)
This commit is contained in:
parent
38ca6f644b
commit
5fb5e24811
@ -140,8 +140,6 @@ while n_cur <= n_len {
|
||||
|
||||
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
|
||||
|
||||
llama_sampler_accept(smpl, new_token_id)
|
||||
|
||||
// is it an end of stream? -> mark the stream as finished
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
i_batch[i] = -1
|
||||
|
@ -172,8 +172,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
|
||||
|
||||
llama_sampler_accept(smpl, new_token_id);
|
||||
|
||||
// is it an end of generation? -> mark the stream as finished
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
|
||||
i_batch[i] = -1;
|
||||
|
@ -121,7 +121,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
|
||||
llama_decode(ctx, bat);
|
||||
|
||||
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
|
||||
llama_sampler_accept(smpl, token);
|
||||
|
||||
if (token == eos_token) {
|
||||
break;
|
||||
|
@ -414,8 +414,6 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
||||
// sample the most likely token
|
||||
const auto new_token_id = llama_sampler_sample(sampler, context, -1);
|
||||
|
||||
llama_sampler_accept(sampler, new_token_id);
|
||||
|
||||
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||
return nullptr;
|
||||
|
@ -152,8 +152,6 @@ actor LlamaContext {
|
||||
|
||||
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
|
||||
|
||||
llama_sampler_accept(sampling, new_token_id)
|
||||
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
print("\n")
|
||||
is_done = true
|
||||
|
@ -220,8 +220,6 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
|
||||
|
||||
llama_sampler_accept(smpl, new_token_id);
|
||||
|
||||
// is it an end of generation?
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||
LOG_TEE("\n");
|
||||
|
@ -74,8 +74,6 @@ int main(int argc, char ** argv) {
|
||||
auto next_token = llama_sampler_sample(smpl, ctx, -1);
|
||||
auto next_token_str = llama_token_to_piece(ctx, next_token);
|
||||
|
||||
llama_sampler_accept(smpl, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result0 += next_token_str;
|
||||
|
||||
@ -132,8 +130,6 @@ int main(int argc, char ** argv) {
|
||||
auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
|
||||
auto next_token_str = llama_token_to_piece(ctx2, next_token);
|
||||
|
||||
llama_sampler_accept(smpl2, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result1 += next_token_str;
|
||||
|
||||
@ -222,8 +218,6 @@ int main(int argc, char ** argv) {
|
||||
auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
|
||||
auto next_token_str = llama_token_to_piece(ctx3, next_token);
|
||||
|
||||
llama_sampler_accept(smpl3, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result2 += next_token_str;
|
||||
|
||||
|
@ -613,7 +613,7 @@ struct server_context {
|
||||
|
||||
gpt_params params;
|
||||
|
||||
llama_batch batch;
|
||||
llama_batch batch = {};
|
||||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
|
@ -118,8 +118,6 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
|
||||
|
||||
llama_sampler_accept(smpl, new_token_id);
|
||||
|
||||
// is it an end of generation?
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
|
||||
LOG_TEE("\n");
|
||||
|
@ -1127,15 +1127,16 @@ extern "C" {
|
||||
int32_t n_logit_bias,
|
||||
const llama_logit_bias * logit_bias);
|
||||
|
||||
// Shorthand for:
|
||||
/// @details Sample and accept a token from the idx-th output of the last evaluation
|
||||
//
|
||||
// Shorthand for:
|
||||
// const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
// llama_token_data_array cur_p = { ... init from logits ... };
|
||||
// llama_sampler_apply(smpl, &cur_p);
|
||||
// return cur_p.data[cur_p.selected].id;
|
||||
//
|
||||
// At this point, this is mostly a convenience function.
|
||||
//
|
||||
// auto token = cur_p.data[cur_p.selected].id;
|
||||
// llama_sampler_accept(smpl, token);
|
||||
// return token;
|
||||
// Returns the sampled token
|
||||
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
|
||||
|
||||
// TODO: extend in the future
|
||||
|
@ -8,49 +8,44 @@
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <unordered_map>
|
||||
|
||||
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
|
||||
#if 1
|
||||
probs.resize(cur_p->size);
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
probs[i] = cur_p->data[i].p;
|
||||
}
|
||||
|
||||
std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
|
||||
#else
|
||||
// avoid the copy with a custom iterator
|
||||
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
|
||||
// iterator for the probabilities
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
|
||||
#endif
|
||||
|
||||
struct probs_iterator {
|
||||
typedef std::input_iterator_tag iterator_category;
|
||||
typedef float value_type;
|
||||
typedef float * pointer;
|
||||
typedef float & reference;
|
||||
typedef size_t difference_type;
|
||||
typedef ptrdiff_t difference_type;
|
||||
|
||||
const llama_token_data_array * data;
|
||||
size_t i;
|
||||
const llama_token_data * data;
|
||||
|
||||
bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; }
|
||||
bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; }
|
||||
float operator*() const { return data->data[i].p; }
|
||||
probs_iterator & operator++() { ++i; return *this; }
|
||||
probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; }
|
||||
bool operator==(const probs_iterator & other) const { return data == other.data; }
|
||||
bool operator!=(const probs_iterator & other) const { return data != other.data; }
|
||||
const float & operator*() const { return data->p; }
|
||||
probs_iterator & operator++() { ++data; return *this; }
|
||||
probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
|
||||
};
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
std::discrete_distribution<size_t> dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size});
|
||||
|
||||
GGML_UNUSED(probs);
|
||||
#endif
|
||||
|
||||
std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
|
||||
|
||||
return dist(rng);
|
||||
}
|
||||
|
||||
/*
|
||||
static void llama_log_softmax(float * array, size_t size) {
|
||||
float max_l = *std::max_element(array, array + size);
|
||||
float sum = 0.f;
|
||||
@ -64,6 +59,7 @@ static void llama_log_softmax(float * array, size_t size) {
|
||||
array[i] = logf(array[i] / sum);
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
||||
GGML_ASSERT(cur_p->size > 0);
|
||||
@ -231,67 +227,92 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||
llama_token_data_array cur_p = {
|
||||
/* .data = */ cur.data(),
|
||||
/* .size = */ cur.size(),
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(smpl, &cur_p);
|
||||
|
||||
return cur_p.data[cur_p.selected].id;
|
||||
GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
|
||||
|
||||
auto token = cur_p.data[cur_p.selected].id;
|
||||
|
||||
llama_sampler_accept(smpl, token);
|
||||
|
||||
return token;
|
||||
}
|
||||
|
||||
// sampler chain
|
||||
|
||||
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "chain";
|
||||
}
|
||||
|
||||
static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_accept(smpl, token);
|
||||
}
|
||||
|
||||
chain->n_sample++;
|
||||
}
|
||||
|
||||
static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_apply(smpl, cur_p);
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_reset(smpl);
|
||||
}
|
||||
|
||||
chain->t_sample_us = 0;
|
||||
chain->n_sample = 0;
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
|
||||
const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
auto * result = llama_sampler_chain_init(chain_src->params);
|
||||
|
||||
for (auto * smpl : chain_src->samplers) {
|
||||
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_free(smpl);
|
||||
}
|
||||
|
||||
delete chain;
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_chain_i = {
|
||||
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
|
||||
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_accept(smpl, token);
|
||||
}
|
||||
|
||||
chain->n_sample++;
|
||||
},
|
||||
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_apply(smpl, cur_p);
|
||||
}
|
||||
},
|
||||
/* .reset = */ [](struct llama_sampler * smpl) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_reset(smpl);
|
||||
}
|
||||
|
||||
chain->t_sample_us = 0;
|
||||
chain->n_sample = 0;
|
||||
},
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
auto * result = llama_sampler_chain_init(chain_src->params);
|
||||
|
||||
for (auto * smpl : chain_src->samplers) {
|
||||
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
|
||||
}
|
||||
|
||||
return result;
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_free(smpl);
|
||||
}
|
||||
|
||||
delete chain;
|
||||
},
|
||||
/* .name = */ llama_sampler_chain_name,
|
||||
/* .accept = */ llama_sampler_chain_accept,
|
||||
/* .apply = */ llama_sampler_chain_apply,
|
||||
/* .reset = */ llama_sampler_chain_reset,
|
||||
/* .clone = */ llama_sampler_chain_clone,
|
||||
/* .free = */ llama_sampler_chain_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
||||
@ -368,8 +389,6 @@ struct llama_sampler_dist {
|
||||
const uint32_t seed;
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
std::vector<float> probs; // work array
|
||||
};
|
||||
|
||||
static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
|
||||
@ -378,7 +397,7 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
|
||||
|
||||
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
||||
cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
|
||||
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
||||
@ -419,7 +438,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||
/* .ctx = */ new llama_sampler_dist {
|
||||
/* .seed = */ seed,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .probs = */ {},
|
||||
},
|
||||
};
|
||||
}
|
||||
@ -1023,8 +1041,6 @@ struct llama_sampler_mirostat {
|
||||
float mu;
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
std::vector<float> probs;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
|
||||
@ -1055,7 +1071,7 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke
|
||||
llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||
|
||||
cur_p->selected = idx;
|
||||
|
||||
@ -1111,7 +1127,6 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
|
||||
/* .m = */ m,
|
||||
/* .mu = */ 2.0f*tau,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .probs = */ {},
|
||||
},
|
||||
};
|
||||
}
|
||||
@ -1127,8 +1142,6 @@ struct llama_sampler_mirostat_v2 {
|
||||
float mu;
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
std::vector<float> probs;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
|
||||
@ -1152,7 +1165,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
|
||||
// Normalize the probabilities of the remaining words
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||
|
||||
cur_p->selected = idx;
|
||||
|
||||
@ -1207,7 +1220,6 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
|
||||
/* .eta = */ eta,
|
||||
/* .mu = */ 2.0f*tau,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .probs = */ {},
|
||||
},
|
||||
};
|
||||
}
|
||||
@ -1527,6 +1539,10 @@ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /
|
||||
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
|
||||
|
||||
if (ctx->logit_bias.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
ctx->to_search.clear();
|
||||
|
||||
// update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
|
||||
@ -1538,6 +1554,10 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx->to_search.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// search for the remaining candidates that were not found in the previous step
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
for (const auto & lb : ctx->to_search) {
|
||||
|
@ -245,7 +245,7 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
|
||||
}
|
||||
}
|
||||
|
||||
printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
|
||||
printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n",
|
||||
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user