Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-01 00:39:00 +01:00)
parallel : process system prompt once + configurable parameters + llama API
This commit is contained in:
parent 82e20e9ba0
commit 4b5f3cd6bf
@@ -317,6 +317,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_chunks = std::stoi(argv[i]);
+        } else if (arg == "-np" || arg == "--parallel") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_parallel = std::stoi(argv[i]);
+        } else if (arg == "-ns" || arg == "--sequences") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_sequences = std::stoi(argv[i]);
         } else if (arg == "-m" || arg == "--model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -360,6 +372,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
             params.simple_io = true;
+        } else if (arg == "--hot-plug") {
+            params.hot_plug = true;
         } else if (arg == "--color") {
             params.use_color = true;
         } else if (arg == "--mlock") {
@@ -659,6 +673,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
+    printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
+    printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
+    printf(" --hot-plug enable hot-plugging of new sequences for decoding (default: disabled)\n");
     if (llama_mlock_supported()) {
         printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
     }
@@ -781,7 +798,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par

         std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
-        llama_kv_cache_rm_tokens(lctx, -1, -1);
+        llama_kv_cache_tokens_rm(lctx, -1, -1);
         llama_reset_timings(lctx);
     }

@@ -1253,6 +1270,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
     fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
+    fprintf(stream, "hot_plug: %s # default: false\n", params.hot_plug ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", params.temp);

     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
@@ -43,6 +43,8 @@ struct gpt_params {
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_draft = 16; // number of tokens to draft during speculative decoding
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
+    int32_t n_parallel = 1; // number of parallel sequences to decode
+    int32_t n_sequences = 1; // number of sequences to decode
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
     int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
@@ -108,6 +110,7 @@ struct gpt_params {
     bool interactive_first = false; // wait for user input immediately
     bool multiline_input = false; // reverse the usage of `\`
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
+    bool hot_plug = false; // hot-plug new sequences for decoding

     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos = false; // ignore generated EOS tokens
@@ -977,7 +977,7 @@ int main(int argc, char ** argv) {

         test t(inst, lmodel, ctx);

-        llama_kv_cache_rm_tokens(ctx, -1, -1);
+        llama_kv_cache_tokens_rm(ctx, -1, -1);

         // warmup run
         if (t.n_prompt > 0) {
@@ -988,7 +988,7 @@ int main(int argc, char ** argv) {
         }

         for (int i = 0; i < params.reps; i++) {
-            llama_kv_cache_rm_tokens(ctx, -1, -1);
+            llama_kv_cache_tokens_rm(ctx, -1, -1);

             uint64_t t_start = get_time_ns();
             if (t.n_prompt > 0) {
@@ -505,8 +505,8 @@ int main(int argc, char ** argv) {
             LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                 n_past, n_left, n_ctx, params.n_keep, n_discard);

-            llama_kv_cache_rm_seq (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
-            llama_kv_cache_shift_seq(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
+            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);

             n_past -= n_discard;

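For readers following the rename: the pair of calls above is the usual context-swap step in main.cpp, now expressed with the seq_rm/seq_shift names. A minimal sketch of the same pattern, wrapped in an illustrative helper (context_swap is not part of the codebase; ctx, n_past, n_keep and n_discard are assumed to be maintained by the caller, as in the hunk):

// Sketch only: drop a block of tokens after the keep region of sequence 0,
// then shift the remaining tokens back so decoding can continue.
#include "llama.h"

static void context_swap(llama_context * ctx, int & n_past, int n_keep, int n_discard) {
    // remove positions [n_keep + 1, n_keep + n_discard + 1) from sequence 0
    llama_kv_cache_seq_rm   (ctx, 0, n_keep + 1, n_keep + n_discard + 1);

    // move the tokens that followed the discarded block back by n_discard positions
    llama_kv_cache_seq_shift(ctx, 0, n_keep + 1 + n_discard, n_past, -n_discard);

    n_past -= n_discard;
}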
@@ -35,7 +35,7 @@ User: Hello, what is the temperature outside?
 Assistant: It is 72 degrees Fahrenheit.
 User: What is the definition of a prime number?
 Assistant: A prime number is a number that is divisible only by itself and 1.
-User: )";
+User:)";

 static std::vector<std::string> k_prompts = {
     "What is the meaning of life?",
@@ -70,7 +70,7 @@ struct client {
     std::string prompt;
     std::string response;

-    std::vector<llama_token> last_tokens;
+    std::vector<llama_token> tokens_prev;
 };

 int main(int argc, char ** argv) {
@@ -80,13 +80,14 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    const int n_clients = 8;
+    // number of simultaneous "clients" to simulate
+    const int32_t n_clients = params.n_parallel;

-    // insert new requests as soon as the previous one is done
-    const bool hot_plug = true;
+    // requests to simulate
+    const int32_t n_seq = params.n_sequences;

-    const int32_t n_seq = 128;
+    // insert new requests as soon as the previous one is done
+    const bool hot_plug = params.hot_plug;

 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("parallel", "log"));
@@ -114,13 +115,17 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.last_tokens.resize(n_ctx);
-        std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
+        client.tokens_prev.resize(n_ctx);
+        std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
     }

     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);

+    std::vector<llama_token> tokens_system;
+    tokens_system = ::llama_tokenize(ctx, k_system, true);
+    const uint32_t n_tokens_system = tokens_system.size();
+
     llama_seq_id g_seq_id = 0;

     std::vector<llama_token> batch_token;
@@ -134,6 +139,44 @@ int main(int argc, char ** argv) {

     const auto t_main_start = ggml_time_us();

+    LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
+    LOG_TEE("%s: n_parallel = %d, n_sequences = %d, hot_plug = %d, system tokens = %d\n", __func__, n_clients, n_seq, hot_plug, n_tokens_system);
+    LOG_TEE("\n");
+
+    {
+        LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
+
+        batch_pos.clear();
+        batch_seq_id.clear();
+
+        for (size_t i = 0; i < n_tokens_system; ++i) {
+            batch_pos.push_back(i);
+            batch_seq_id.push_back(0);
+        }
+
+        llama_batch batch = {
+            n_tokens_system,
+            tokens_system.data(),
+            nullptr,
+            batch_pos.data(),
+            batch_seq_id.data(),
+            nullptr,
+            0, 0, 0, // unused
+        };
+
+        if (llama_decode(ctx, batch, params.n_threads) != 0) {
+            LOG_TEE("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        // assign the system KV cache to all parallel sequences
+        for (int32_t i = 1; i < n_clients; ++i) {
+            llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
+        }
+
+        LOG_TEE("\n");
+    }
+
     while (true) {
         uint32_t n_tokens = 0;
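The block above is the core of the change: the shared system prompt is decoded once into sequence 0, and llama_kv_cache_seq_cp then tags the same cells for every client sequence, so the prompt is never re-evaluated per client. A condensed sketch of the pattern follows; the helper name eval_system_prompt is illustrative, and the llama_batch field order mirrors the initialization shown in the hunk.

// Sketch only: evaluate a shared system prompt once, then reuse its KV cells
// for all parallel sequences instead of re-evaluating it per client.
// Assumes the examples' common.h (gpt_params, ::llama_tokenize) and llama.h.
#include "common.h"
#include "llama.h"

#include <string>
#include <vector>

static bool eval_system_prompt(llama_context * ctx, const gpt_params & params,
                               const std::string & k_system, int32_t n_clients) {
    std::vector<llama_token> tokens_system = ::llama_tokenize(ctx, k_system, true);
    const uint32_t n_tokens_system = tokens_system.size();

    std::vector<llama_pos>    batch_pos;
    std::vector<llama_seq_id> batch_seq_id;
    for (size_t i = 0; i < n_tokens_system; ++i) {
        batch_pos.push_back(i);
        batch_seq_id.push_back(0); // the system prompt lives in sequence 0
    }

    llama_batch batch = {
        n_tokens_system,
        tokens_system.data(),
        nullptr,              // no embeddings
        batch_pos.data(),
        batch_seq_id.data(),
        nullptr,              // no per-token logit flags needed here
        0, 0, 0,              // unused when pos/seq_id are provided
    };

    if (llama_decode(ctx, batch, params.n_threads) != 0) {
        return false;
    }

    // tag the same cells for the other sequences - the KV data is shared, not duplicated
    for (int32_t i = 1; i < n_clients; ++i) {
        llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
    }

    return true;
}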
@@ -148,7 +191,7 @@ int main(int argc, char ** argv) {
             }

             batch_token.push_back(client.sampled);
-            batch_pos.push_back(client.n_decoded + client.n_prompt);
+            batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded);
             batch_seq_id.push_back(client.seq_id);
             batch_logits.push_back(true);
             batch_clients.push_back(&client);
@@ -158,34 +201,36 @@ int main(int argc, char ** argv) {

         if (batch_token.empty()) {
             // all sequences have ended - clear the entire KV cache
-            llama_kv_cache_rm_tokens(ctx, -1, -1);
+            for (int i = 0; i < n_clients; ++i) {
+                llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
+            }
         }

         if (hot_plug || batch_token.empty()) {
             for (auto & client : clients) {
                 if (client.seq_id == -1 && g_seq_id < n_seq) {
-                    client.seq_id = g_seq_id;
+                    client.seq_id = client.id;
                     client.t_start_prompt = ggml_time_us();
                     client.t_start_gen = 0;

                     client.input = k_prompts[rand() % k_prompts.size()];
-                    client.prompt = k_system + client.input + "\nAssistant:";
+                    client.prompt = client.input + "\nAssistant:";
                     client.response = "";
-                    std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
+                    std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);

-                    std::vector<llama_token> prompt_tokens;
-                    prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
+                    std::vector<llama_token> tokens_prompt;
+                    tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);

-                    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-                        batch_token.push_back(prompt_tokens[i]);
-                        batch_pos.push_back(i);
+                    for (size_t i = 0; i < tokens_prompt.size(); ++i) {
+                        batch_token.push_back(tokens_prompt[i]);
+                        batch_pos.push_back(i + n_tokens_system);
                         batch_seq_id.push_back(client.seq_id);
                         batch_clients.push_back(&client);
                         batch_logits.push_back(false);
                     }
                     batch_logits.back() = true;

-                    client.n_prompt = prompt_tokens.size();
+                    client.n_prompt = tokens_prompt.size();
                     client.n_decoded = 0;
                     client.i_batch = batch_token.size() - 1;

@@ -217,9 +262,10 @@ int main(int argc, char ** argv) {
                 0, 0, 0, // unused
             };

-            if (llama_decode(ctx, batch, params.n_threads)) {
-                if (n_batch == 1) {
-                    LOG_TEE("%s : failed to decode batch\n", __func__);
+            const int ret = llama_decode(ctx, batch, params.n_threads);
+            if (ret != 0) {
+                if (n_batch == 1 || ret < 0) {
+                    LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
                     return 1;
                 }

@@ -242,7 +288,7 @@ int main(int argc, char ** argv) {
             //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
             //    client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);

-            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
+            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);

             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
@@ -251,8 +297,8 @@ int main(int argc, char ** argv) {
             }

             // remember which tokens were sampled - used for repetition penalties during sampling
-            client.last_tokens.erase(client.last_tokens.begin());
-            client.last_tokens.push_back(id);
+            client.tokens_prev.erase(client.tokens_prev.begin());
+            client.tokens_prev.push_back(id);

             const std::string token_str = llama_token_to_piece(ctx, id);
             client.response += token_str;
@@ -271,7 +317,8 @@ int main(int argc, char ** argv) {
                     client.response = client.response.substr(0, pos);
                 }

-                llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx);
+                // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
+                llama_kv_cache_seq_rm(ctx, client.seq_id, n_tokens_system, n_ctx);

                 const auto t_main_end = ggml_time_us();

@@ -207,7 +207,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         const auto t_start = std::chrono::high_resolution_clock::now();

         // clear the KV cache
-        llama_kv_cache_keep_seq(ctx, -1);
+        llama_kv_cache_tokens_rm(ctx, -1, -1);

         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -335,7 +335,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const auto t_start = std::chrono::high_resolution_clock::now();

         // clear the KV cache
-        llama_kv_cache_keep_seq(ctx, -1);
+        llama_kv_cache_tokens_rm(ctx, -1, -1);

         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -568,7 +568,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }

         // clear the KV cache
-        llama_kv_cache_keep_seq(ctx, -1);
+        llama_kv_cache_tokens_rm(ctx, -1, -1);

         auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
         if (logits.empty()) {
@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
                 LOG("out of drafted tokens\n");
             }

-            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
+            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
             ++n_past_dft;

@@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
             }

             // evaluate the drafted token on the draft model
-            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
+            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
             ++n_past_cur;

@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         }

         // evaluate the target model on the drafted tokens
-        llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
+        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
         ++n_past_tgt;

llama.cpp
@@ -1328,7 +1328,7 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
     return 0;
 }

-static void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
+static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
     if (c0 < 0) c0 = 0;
     if (c1 < 0) c1 = cache.size;

@@ -1338,7 +1338,7 @@ static void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0,
     }
 }

-static void llama_kv_cache_rm_seq(
+static void llama_kv_cache_seq_rm(
         struct llama_kv_cache & cache,
         llama_seq_id seq_id,
         llama_pos p0,
@@ -1353,7 +1353,20 @@ static void llama_kv_cache_rm_seq(
     }
 }

-static void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+static void llama_kv_cache_seq_cp(
+        struct llama_kv_cache & cache,
+        llama_seq_id seq_id_src,
+        llama_seq_id seq_id_dst,
+        llama_pos p0,
+        llama_pos p1) {
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.cells[i].seq_id.insert(seq_id_dst);
+        }
+    }
+}
+
+static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
             cache.cells[i].pos = -1;
@@ -1362,7 +1375,7 @@ static void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id
     }
 }

-static void llama_kv_cache_shift_seq(
+static void llama_kv_cache_seq_shift(
         struct llama_kv_cache & cache,
         llama_seq_id seq_id,
         llama_pos p0,
@@ -4019,7 +4032,11 @@ static struct ggml_cgraph * llama_build_graph(
 //   - batch: batch to evaluate
 //   - n_threads: number of threads to use
 //
-static bool llama_decode_internal(
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_decode_internal(
         llama_context & lctx,
         llama_batch batch,
         int n_threads) {
@@ -4027,7 +4044,7 @@ static bool llama_decode_internal(

     if (n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
-        return false;
+        return -1;
     }

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
@@ -4079,7 +4096,7 @@ static bool llama_decode_internal(
         kv_self.head = 0;

         if (!llama_kv_cache_find_slot(kv_self, batch)) {
-            return false;
+            return 1;
         }

     // a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -4203,7 +4220,14 @@ static bool llama_decode_internal(
         lctx.n_p_eval += n_tokens;
     }

-    return true;
+    // get a more accurate load time, upon first eval
+    // TODO: fix this
+    if (!lctx.has_evaluated_once) {
+        lctx.t_load_us = ggml_time_us() - lctx.t_start_us;
+        lctx.has_evaluated_once = true;
+    }
+
+    return 0;
 }

 //
@@ -6920,20 +6944,24 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     return ctx->kv_self.head;
 }

-void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1) {
-    llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1);
+void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) {
+    llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1);
 }

-void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    llama_kv_cache_rm_seq(ctx->kv_self, seq_id, p0, p1);
+void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
 }

-void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
+void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
 }

-void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    llama_kv_cache_shift_seq(ctx->kv_self, seq_id, p0, p1, delta);
+void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
+}
+
+void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
 }

 // Returns the *maximum* size of the state
@@ -7330,21 +7358,18 @@ int llama_eval(
         uint32_t n_tokens,
         int n_past,
         int n_threads) {
-    llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
+    llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);

-    if (!llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
-        //LLAMA_LOG_ERROR("%s: failed to decode\n", __func__);
-        return 1;
+    const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads);
+    if (ret != 0) {
+        if (ret < 0) {
+            LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
+        }
+
+        return ret;
     }

-    // get a more accurate load time, upon first eval
-    // TODO: fix this
-    if (!ctx->has_evaluated_once) {
-        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
-        ctx->has_evaluated_once = true;
-    }
-
-    return 0;
+    return ret;
 }

 int llama_eval_embd(
@@ -7353,23 +7378,20 @@ int llama_eval_embd(
         uint32_t n_tokens,
         int n_past,
         int n_threads) {
-    llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
+    llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);

     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };

-    if (!llama_decode_internal(*ctx, batch, n_threads)) {
-        //LLAMA_LOG_ERROR("%s: failed to decode\n", __func__);
-        return 1;
+    const int ret = llama_decode_internal(*ctx, batch, n_threads);
+    if (ret != 0) {
+        if (ret < 0) {
+            LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
+        }
+
+        return ret;
     }

-    // get a more accurate load time, upon first eval
-    // TODO: fix this
-    if (!ctx->has_evaluated_once) {
-        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
-        ctx->has_evaluated_once = true;
-    }
-
-    return 0;
+    return ret;
 }

 struct llama_batch llama_batch_get_one(
@@ -7394,19 +7416,16 @@ int llama_decode(
         struct llama_context * ctx,
         struct llama_batch batch,
         int n_threads) {
-    if (!llama_decode_internal(*ctx, batch, n_threads)) {
-        //LLAMA_LOG_ERROR("%s: failed to decode\n", __func__);
-        return 1;
+    const int ret = llama_decode_internal(*ctx, batch, n_threads);
+    if (ret != 0) {
+        if (ret < 0) {
+            LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
+        }
+
+        return ret;
     }

-    // get a more accurate load time, upon first eval
-    // TODO: fix this
-    if (!ctx->has_evaluated_once) {
-        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
-        ctx->has_evaluated_once = true;
-    }
-
-    return 0;
+    return ret;
 }

 float * llama_get_logits(struct llama_context * ctx) {
llama.h
@@ -322,17 +322,20 @@ extern "C" {
             "avoid using this, it will be removed in the future, instead - count the tokens in user code");

     // Remove all tokens data of cells in [c0, c1)
-    LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
+    LLAMA_API void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1);

     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
-    LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
+    LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);

+    // Copy all tokens that belong to the specified sequence to another sequence
+    LLAMA_API void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
+
     // Removes all tokens that do not belong to the specified sequence
-    LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
+    LLAMA_API void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id);

     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly
-    LLAMA_API void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
+    LLAMA_API void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);

     //
     // State / sessions
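Taken together, the renamed and new entry points give callers per-sequence control over the cache. A minimal sketch of two typical maintenance operations (the helper names are illustrative, not part of llama.h; n_tokens_system and n_ctx are assumed to be provided by the caller, as in examples/parallel):

// Sketch only - assumes llama.h is available.
#include "llama.h"

// Recycle a finished client sequence while keeping the shared system prompt
// (the first n_tokens_system cells) resident in the cache.
static void recycle_sequence(llama_context * ctx, llama_seq_id seq_id,
                             llama_pos n_tokens_system, llama_pos n_ctx) {
    llama_kv_cache_seq_rm(ctx, seq_id, n_tokens_system, n_ctx);
}

// Drop every cell that does not belong to the given sequence,
// e.g. to fall back to single-sequence decoding.
static void keep_only_sequence(llama_context * ctx, llama_seq_id seq_id) {
    llama_kv_cache_seq_keep(ctx, seq_id);
}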
@@ -391,6 +394,10 @@ extern "C" {
             llama_pos pos_0,
             llama_seq_id seq_id);

+    // Positive return values do not mean a fatal error, but rather a warning.
+    //   0 - success
+    //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    // < 0 - error
     LLAMA_API int llama_decode(
             struct llama_context * ctx,
             struct llama_batch batch,