mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
simple : add parallel decoding support
This commit is contained in:
parent
addae65fd4
commit
b377bf2266
@ -956,11 +956,11 @@ llama_token llama_sample_token(
|
||||
if (mirostat == 1) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temperature(ctx, &cur_p, temp);
|
||||
llama_sample_temp(ctx, &cur_p, temp);
|
||||
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||
} else if (mirostat == 2) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
llama_sample_temperature(ctx, &cur_p, temp);
|
||||
llama_sample_temp(ctx, &cur_p, temp);
|
||||
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
@ -968,7 +968,7 @@ llama_token llama_sample_token(
|
||||
llama_sample_tail_free (ctx, &cur_p, tfs_z, 1);
|
||||
llama_sample_typical (ctx, &cur_p, typical_p, 1);
|
||||
llama_sample_top_p (ctx, &cur_p, top_p, 1);
|
||||
llama_sample_temperature(ctx, &cur_p, temp);
|
||||
llama_sample_temp(ctx, &cur_p, temp);
|
||||
|
||||
{
|
||||
const int n_top = 10;
|
||||
|
@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
@ -183,11 +183,11 @@ llama_token sampling_id(struct MyModel* mymodel) {
|
||||
if (mirostat == 1) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||
} else if (mirostat == 2) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
@ -195,7 +195,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
|
||||
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
|
||||
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token(ctx, &candidates_p);
|
||||
}
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
tokens_system = ::llama_tokenize(ctx, k_system, true);
|
||||
const uint32_t n_tokens_system = tokens_system.size();
|
||||
const int32_t n_tokens_system = tokens_system.size();
|
||||
|
||||
llama_seq_id g_seq_id = 0;
|
||||
|
||||
@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
batch.n_tokens = n_tokens_system;
|
||||
|
||||
for (uint32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = tokens_system[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
@ -156,7 +156,7 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// assign the system KV cachce to all parallel sequences
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i < n_clients; ++i) {
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
|
||||
}
|
||||
@ -248,7 +248,7 @@ int main(int argc, char ** argv) {
|
||||
int32_t n_batch = params.n_batch;
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
|
@ -523,13 +523,13 @@ struct llama_server_context
|
||||
{
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||
}
|
||||
else if (mirostat == 2)
|
||||
{
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||
}
|
||||
else
|
||||
@ -540,7 +540,7 @@ struct llama_server_context
|
||||
llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
|
||||
llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
result.tok = llama_sample_token(ctx, &candidates_p);
|
||||
}
|
||||
}
|
||||
|
@ -32,12 +32,18 @@ int main(int argc, char ** argv) {
|
||||
params.prompt = "Hello my name is";
|
||||
}
|
||||
|
||||
// total length of the sequences including the prompt
|
||||
const int n_len = 32;
|
||||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
|
||||
ctx_params.seed = 1234;
|
||||
ctx_params.n_ctx = 2048;
|
||||
|
||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
|
||||
|
||||
if (model == NULL) {
|
||||
@ -47,20 +53,29 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// tokenize the prompt
|
||||
|
||||
std::vector<llama_token> tokens_list;
|
||||
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
||||
|
||||
const int max_context_size = llama_n_ctx(ctx);
|
||||
const int max_tokens_list_size = max_context_size - 4;
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
|
||||
|
||||
if ((int) tokens_list.size() > max_tokens_list_size) {
|
||||
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size);
|
||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
|
||||
|
||||
// make sure wi
|
||||
if (n_kv_req > n_ctx) {
|
||||
LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
|
||||
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
for (auto id : tokens_list) {
|
||||
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
|
||||
@ -68,66 +83,157 @@ int main(int argc, char ** argv) {
|
||||
|
||||
fflush(stderr);
|
||||
|
||||
// create a llama_batch with size 512
|
||||
// we use this object to submit token data for decoding
|
||||
|
||||
llama_batch batch = llama_batch_init(512, 0);
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = tokens_list.size();
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
batch.token[i] = tokens_list[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
if (llama_decode(ctx, batch, params.n_threads) != 0) {
|
||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i < n_parallel; ++i) {
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
|
||||
}
|
||||
|
||||
if (n_parallel > 1) {
|
||||
LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
|
||||
}
|
||||
|
||||
// main loop
|
||||
|
||||
// The LLM keeps a contextual cache memory of previous token evaluation.
|
||||
// Usually, once this cache is full, it is required to recompute a compressed context based on previous
|
||||
// tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
|
||||
// example, we will just stop the loop once this cache is full or once an end of stream is detected.
|
||||
// we will store the parallel decoded sequences in this vector
|
||||
std::vector<std::string> streams(n_parallel);
|
||||
|
||||
const int n_gen = std::min(32, max_context_size);
|
||||
// remember the batch index of the last tokenn for each parallel sequence
|
||||
// we will use this to know which logits to sample from
|
||||
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
|
||||
|
||||
int n_cur = 0;
|
||||
int n_cur = batch.n_tokens;
|
||||
int n_decode = 0;
|
||||
|
||||
while (n_cur < n_gen) {
|
||||
// evaluate the transformer
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), int(tokens_list.size()), n_cur, 0), params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
while (n_cur <= n_len) {
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
n_cur += tokens_list.size();
|
||||
tokens_list.clear();
|
||||
// prepare the next batch
|
||||
batch.n_tokens = 0;
|
||||
|
||||
// sample the next token
|
||||
// sample the next token for each parallel sequence / stream
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
if (i_batch[i] < 0) {
|
||||
// the stream has already finished
|
||||
continue;
|
||||
}
|
||||
|
||||
llama_token new_token_id = 0;
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto logits = llama_get_logits(ctx) + i_batch[i] * n_vocab;
|
||||
|
||||
auto logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
const int top_k = 40;
|
||||
const float top_p = 0.9f;
|
||||
const float temp = 0.4f;
|
||||
|
||||
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||
llama_sample_temp (ctx, &candidates_p, temp);
|
||||
|
||||
const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
|
||||
|
||||
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
|
||||
// is it an end of stream ?
|
||||
// mark this stream as finished
|
||||
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
|
||||
i_batch[i] = -1;
|
||||
LOG_TEE("\n");
|
||||
if (n_parallel > 1) {
|
||||
LOG_TEE("%s: stream %d finished", __func__, i);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if (n_parallel == 1) {
|
||||
// print the new token :
|
||||
LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
streams[i] += llama_token_to_piece(ctx, new_token_id);
|
||||
|
||||
// push this new token for next evaluation
|
||||
batch.token [batch.n_tokens] = new_token_id;
|
||||
batch.pos [batch.n_tokens] = n_cur;
|
||||
batch.seq_id[batch.n_tokens] = i;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
|
||||
i_batch[i] = batch.n_tokens;
|
||||
|
||||
batch.n_tokens += 1;
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
new_token_id = llama_sample_token_greedy(ctx , &candidates_p);
|
||||
|
||||
// is it an end of stream ?
|
||||
if (new_token_id == llama_token_eos(ctx)) {
|
||||
fprintf(stderr, " [end of text]\n");
|
||||
if (batch.n_tokens == 0) {
|
||||
// all streams are finished
|
||||
break;
|
||||
}
|
||||
|
||||
// print the new token :
|
||||
printf("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
||||
fflush(stdout);
|
||||
|
||||
// push this new token for next evaluation
|
||||
tokens_list.push_back(new_token_id);
|
||||
n_cur += 1;
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
||||
if (n_parallel > 1) {
|
||||
LOG_TEE("\n");
|
||||
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
34
llama.cpp
34
llama.cpp
@ -4185,20 +4185,18 @@ static int llama_decode_internal(
|
||||
{
|
||||
auto & logits_out = lctx.logits;
|
||||
|
||||
if (lctx.logits_all) {
|
||||
if (batch.logits) {
|
||||
logits_out.resize(n_vocab * n_tokens);
|
||||
if (batch.logits) {
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (batch.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (batch.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
||||
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
|
||||
}
|
||||
} else if (lctx.logits_all) {
|
||||
logits_out.resize(n_vocab * n_tokens);
|
||||
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
||||
} else {
|
||||
// return result for just the last token
|
||||
logits_out.resize(n_vocab);
|
||||
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
|
||||
}
|
||||
@ -5269,7 +5267,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
|
||||
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
@ -5281,6 +5279,10 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
|
||||
llama_sample_temp(ctx, candidates_p, temp);
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) {
|
||||
if (last_tokens_size == 0 || penalty == 1.0f) {
|
||||
return;
|
||||
@ -7357,7 +7359,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
|
||||
int llama_eval(
|
||||
struct llama_context * ctx,
|
||||
llama_token * tokens,
|
||||
uint32_t n_tokens,
|
||||
int32_t n_tokens,
|
||||
int n_past,
|
||||
int n_threads) {
|
||||
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
||||
@ -7377,7 +7379,7 @@ int llama_eval(
|
||||
int llama_eval_embd(
|
||||
struct llama_context * ctx,
|
||||
float * embd,
|
||||
uint32_t n_tokens,
|
||||
int32_t n_tokens,
|
||||
int n_past,
|
||||
int n_threads) {
|
||||
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
||||
@ -7398,7 +7400,7 @@ int llama_eval_embd(
|
||||
|
||||
struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
uint32_t n_tokens,
|
||||
int32_t n_tokens,
|
||||
llama_pos pos_0,
|
||||
llama_seq_id seq_id) {
|
||||
return {
|
||||
@ -7414,8 +7416,8 @@ struct llama_batch llama_batch_get_one(
|
||||
};
|
||||
}
|
||||
|
||||
struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd) {
|
||||
llama_batch batch = { n_tokens, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
|
||||
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
|
||||
llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
|
||||
|
||||
if (embd) {
|
||||
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
|
||||
|
15
llama.h
15
llama.h
@ -68,7 +68,7 @@ extern "C" {
|
||||
|
||||
// data used for batch inference
|
||||
typedef struct llama_batch {
|
||||
uint32_t n_tokens;
|
||||
int32_t n_tokens;
|
||||
|
||||
llama_token * token;
|
||||
float * embd;
|
||||
@ -370,7 +370,7 @@ extern "C" {
|
||||
LLAMA_API DEPRECATED(int llama_eval(
|
||||
struct llama_context * ctx,
|
||||
llama_token * tokens,
|
||||
uint32_t n_tokens,
|
||||
int32_t n_tokens,
|
||||
int n_past,
|
||||
int n_threads),
|
||||
"please use llama_decode() instead");
|
||||
@ -380,7 +380,7 @@ extern "C" {
|
||||
LLAMA_API DEPRECATED(int llama_eval_embd(
|
||||
struct llama_context * ctx,
|
||||
float * embd,
|
||||
uint32_t n_tokens,
|
||||
int32_t n_tokens,
|
||||
int n_past,
|
||||
int n_threads),
|
||||
"please use llama_decode() instead");
|
||||
@ -391,7 +391,7 @@ extern "C" {
|
||||
//
|
||||
LLAMA_API struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
uint32_t n_tokens,
|
||||
int32_t n_tokens,
|
||||
llama_pos pos_0,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
@ -401,7 +401,7 @@ extern "C" {
|
||||
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
|
||||
// The rest of the llama_batch members are allocated with size n_tokens
|
||||
// All members are left uninitialized
|
||||
LLAMA_API struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd);
|
||||
LLAMA_API struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd);
|
||||
|
||||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||
@ -531,7 +531,10 @@ extern "C" {
|
||||
|
||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||
LLAMA_API void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||
|
||||
LLAMA_API DEPRECATED(void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp),
|
||||
"Use llama_sample_temp instead");
|
||||
|
||||
/// @details Apply constraints from grammar
|
||||
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
|
||||
|
Loading…
Reference in New Issue
Block a user