mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-06 16:40:34 +01:00
context : initial need_reserve logic
ggml-ci
This commit is contained in:
parent
2f6d767fc5
commit
1f7dc12222
@ -576,9 +576,7 @@ ggml_tensor * llama_context::build_lora_mm_id(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_context::kv_self_update() {
|
void llama_context::kv_self_update() {
|
||||||
bool need_reserve = false;
|
|
||||||
|
|
||||||
auto & kv = kv_self;
|
auto & kv = kv_self;
|
||||||
|
|
||||||
if (kv.has_shift) {
|
if (kv.has_shift) {
|
||||||
@ -655,12 +653,14 @@ bool llama_context::kv_self_update() {
|
|||||||
|
|
||||||
ggml_free(ctx0);
|
ggml_free(ctx0);
|
||||||
|
|
||||||
need_reserve = true;
|
|
||||||
|
|
||||||
kv.do_defrag = false;
|
kv.do_defrag = false;
|
||||||
}
|
|
||||||
|
|
||||||
return need_reserve;
|
need_reserve = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_self_update(llama_context * ctx) {
|
||||||
|
ctx->kv_self_update();
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_context::build_attn_inp(
|
void llama_context::build_attn_inp(
|
||||||
@ -1824,6 +1824,165 @@ int32_t llama_apply_adapter_cvec(
|
|||||||
return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
|
return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// kv cache view
|
||||||
|
//
|
||||||
|
|
||||||
|
struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
|
||||||
|
return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
|
||||||
|
llama_kv_cache_view_update(view, ctx->kv_self);
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// kv cache
|
||||||
|
//
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
|
||||||
|
return llama_kv_self_n_tokens(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||||
|
return llama_kv_cache_n_tokens(&ctx->kv_self);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
|
||||||
|
return llama_kv_self_used_cells(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||||
|
return llama_kv_cache_used_cells(&ctx->kv_self);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
void llama_kv_cache_clear(llama_context * ctx) {
|
||||||
|
llama_kv_self_clear(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_self_clear(llama_context * ctx) {
|
||||||
|
llama_kv_cache_clear(&ctx->kv_self);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
bool llama_kv_cache_seq_rm(
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1) {
|
||||||
|
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_kv_self_seq_rm(
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1) {
|
||||||
|
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
void llama_kv_cache_seq_cp(
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id_src,
|
||||||
|
llama_seq_id seq_id_dst,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1) {
|
||||||
|
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_self_seq_cp(
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id_src,
|
||||||
|
llama_seq_id seq_id_dst,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1) {
|
||||||
|
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
void llama_kv_cache_seq_keep(
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id) {
|
||||||
|
return llama_kv_self_seq_keep(ctx, seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||||
|
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
void llama_kv_cache_seq_add(
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1,
|
||||||
|
llama_pos delta) {
|
||||||
|
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_self_seq_add(
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1,
|
||||||
|
llama_pos delta) {
|
||||||
|
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
void llama_kv_cache_seq_div(
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1,
|
||||||
|
int d) {
|
||||||
|
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_self_seq_div(
|
||||||
|
llama_context * ctx,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
llama_pos p0,
|
||||||
|
llama_pos p1,
|
||||||
|
int d) {
|
||||||
|
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||||
|
return llama_kv_self_seq_pos_max(ctx, seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||||
|
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
void llama_kv_cache_defrag(llama_context * ctx) {
|
||||||
|
return llama_kv_self_defrag(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_self_defrag(llama_context * ctx) {
|
||||||
|
return llama_kv_cache_defrag(&ctx->kv_self);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
bool llama_kv_cache_can_shift(const llama_context * ctx) {
|
||||||
|
return llama_kv_self_can_shift(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||||
|
return llama_kv_cache_can_shift(&ctx->kv_self);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
void llama_kv_cache_update(llama_context * ctx) {
|
||||||
|
llama_kv_self_update(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
// llama state API
|
// llama state API
|
||||||
|
|
||||||
|
@ -62,6 +62,7 @@ struct llama_context {
|
|||||||
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
||||||
|
|
||||||
bool logits_all = false;
|
bool logits_all = false;
|
||||||
|
bool need_reserve = false;
|
||||||
|
|
||||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||||
@ -87,6 +88,7 @@ struct llama_context {
|
|||||||
// max token position across all sequences in the current context
|
// max token position across all sequences in the current context
|
||||||
llama_pos pos_max() const;
|
llama_pos pos_max() const;
|
||||||
|
|
||||||
|
// certain implementations could require a padding for the context size
|
||||||
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
|
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
|
||||||
|
|
||||||
void reset();
|
void reset();
|
||||||
@ -140,7 +142,7 @@ struct llama_context {
|
|||||||
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
||||||
|
|
||||||
// return true if need to reserve new worst-case graph
|
// return true if need to reserve new worst-case graph
|
||||||
bool kv_self_update();
|
void kv_self_update();
|
||||||
|
|
||||||
void build_attn_inp(
|
void build_attn_inp(
|
||||||
ggml_context * ctx0,
|
ggml_context * ctx0,
|
||||||
|
337
src/llama.cpp
337
src/llama.cpp
@ -28,57 +28,6 @@
|
|||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
|
||||||
static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
|
|
||||||
// loading time will be recalculated after the first eval, so
|
|
||||||
// we take page faults deferred by mmap() into consideration
|
|
||||||
model.t_load_us = 0;
|
|
||||||
time_meas tm(model.t_load_us);
|
|
||||||
|
|
||||||
model.t_start_us = tm.t_start_us;
|
|
||||||
|
|
||||||
try {
|
|
||||||
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
|
|
||||||
|
|
||||||
ml.print_info();
|
|
||||||
|
|
||||||
model.hparams.vocab_only = params.vocab_only;
|
|
||||||
|
|
||||||
try {
|
|
||||||
model.load_arch(ml);
|
|
||||||
} catch(const std::exception & e) {
|
|
||||||
throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
model.load_hparams(ml);
|
|
||||||
} catch(const std::exception & e) {
|
|
||||||
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
model.load_vocab(ml);
|
|
||||||
} catch(const std::exception & e) {
|
|
||||||
throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
|
|
||||||
}
|
|
||||||
|
|
||||||
model.load_stats(ml);
|
|
||||||
model.print_info();
|
|
||||||
|
|
||||||
if (params.vocab_only) {
|
|
||||||
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!model.load_tensors(ml)) {
|
|
||||||
return -2;
|
|
||||||
}
|
|
||||||
} catch (const std::exception & err) {
|
|
||||||
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// llm_build
|
// llm_build
|
||||||
//
|
//
|
||||||
@ -7951,6 +7900,30 @@ static int llama_decode_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reserve a worst case graph if needed
|
||||||
|
// TODO: extract to a function
|
||||||
|
if (lctx.need_reserve) {
|
||||||
|
const auto & cparams = lctx.cparams;
|
||||||
|
const auto & model = lctx.model;
|
||||||
|
|
||||||
|
// build worst-case graph
|
||||||
|
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||||
|
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||||
|
|
||||||
|
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||||
|
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||||
|
|
||||||
|
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
||||||
|
|
||||||
|
// initialize scheduler with the worst-case graph
|
||||||
|
ggml_backend_sched_reset(lctx.sched.get());
|
||||||
|
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
lctx.need_reserve = false;
|
||||||
|
}
|
||||||
|
|
||||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||||
|
|
||||||
ggml_backend_sched_reset(lctx.sched.get());
|
ggml_backend_sched_reset(lctx.sched.get());
|
||||||
@ -8206,6 +8179,31 @@ static int llama_encode_impl(
|
|||||||
|
|
||||||
lctx.prepare_decode(ubatch);
|
lctx.prepare_decode(ubatch);
|
||||||
|
|
||||||
|
// reserve a worst case graph if needed
|
||||||
|
// TODO: extract to a function
|
||||||
|
if (lctx.need_reserve) {
|
||||||
|
// TODO: extract to a function
|
||||||
|
const auto & cparams = lctx.cparams;
|
||||||
|
const auto & model = lctx.model;
|
||||||
|
|
||||||
|
// build worst-case graph
|
||||||
|
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||||
|
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||||
|
|
||||||
|
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||||
|
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||||
|
|
||||||
|
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
||||||
|
|
||||||
|
// initialize scheduler with the worst-case graph
|
||||||
|
ggml_backend_sched_reset(lctx.sched.get());
|
||||||
|
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
lctx.need_reserve = false;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_backend_sched_reset(lctx.sched.get());
|
ggml_backend_sched_reset(lctx.sched.get());
|
||||||
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||||
|
|
||||||
@ -8419,6 +8417,57 @@ int64_t llama_time_us(void) {
|
|||||||
return ggml_time_us();
|
return ggml_time_us();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
||||||
|
static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
|
||||||
|
// loading time will be recalculated after the first eval, so
|
||||||
|
// we take page faults deferred by mmap() into consideration
|
||||||
|
model.t_load_us = 0;
|
||||||
|
time_meas tm(model.t_load_us);
|
||||||
|
|
||||||
|
model.t_start_us = tm.t_start_us;
|
||||||
|
|
||||||
|
try {
|
||||||
|
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
|
||||||
|
|
||||||
|
ml.print_info();
|
||||||
|
|
||||||
|
model.hparams.vocab_only = params.vocab_only;
|
||||||
|
|
||||||
|
try {
|
||||||
|
model.load_arch(ml);
|
||||||
|
} catch(const std::exception & e) {
|
||||||
|
throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
model.load_hparams(ml);
|
||||||
|
} catch(const std::exception & e) {
|
||||||
|
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
model.load_vocab(ml);
|
||||||
|
} catch(const std::exception & e) {
|
||||||
|
throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
|
||||||
|
}
|
||||||
|
|
||||||
|
model.load_stats(ml);
|
||||||
|
model.print_info();
|
||||||
|
|
||||||
|
if (params.vocab_only) {
|
||||||
|
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!model.load_tensors(ml)) {
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
} catch (const std::exception & err) {
|
||||||
|
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
static struct llama_model * llama_model_load_from_file_impl(
|
static struct llama_model * llama_model_load_from_file_impl(
|
||||||
const std::string & path_model,
|
const std::string & path_model,
|
||||||
std::vector<std::string> & splits,
|
std::vector<std::string> & splits,
|
||||||
@ -8879,192 +8928,6 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
return llama_init_from_model(model, params);
|
return llama_init_from_model(model, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// kv cache view
|
|
||||||
//
|
|
||||||
|
|
||||||
struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
|
|
||||||
return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
|
|
||||||
llama_kv_cache_view_update(view, ctx->kv_self);
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// kv cache
|
|
||||||
//
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
|
|
||||||
return llama_kv_self_n_tokens(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
|
||||||
return llama_kv_cache_n_tokens(&ctx->kv_self);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
|
|
||||||
return llama_kv_self_used_cells(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
|
||||||
return llama_kv_cache_used_cells(&ctx->kv_self);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
void llama_kv_cache_clear(llama_context * ctx) {
|
|
||||||
llama_kv_self_clear(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_self_clear(llama_context * ctx) {
|
|
||||||
llama_kv_cache_clear(&ctx->kv_self);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
bool llama_kv_cache_seq_rm(
|
|
||||||
llama_context * ctx,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1) {
|
|
||||||
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool llama_kv_self_seq_rm(
|
|
||||||
llama_context * ctx,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1) {
|
|
||||||
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
void llama_kv_cache_seq_cp(
|
|
||||||
llama_context * ctx,
|
|
||||||
llama_seq_id seq_id_src,
|
|
||||||
llama_seq_id seq_id_dst,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1) {
|
|
||||||
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_self_seq_cp(
|
|
||||||
llama_context * ctx,
|
|
||||||
llama_seq_id seq_id_src,
|
|
||||||
llama_seq_id seq_id_dst,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1) {
|
|
||||||
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
void llama_kv_cache_seq_keep(
|
|
||||||
llama_context * ctx,
|
|
||||||
llama_seq_id seq_id) {
|
|
||||||
return llama_kv_self_seq_keep(ctx, seq_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
|
||||||
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
void llama_kv_cache_seq_add(
|
|
||||||
llama_context * ctx,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
llama_pos delta) {
|
|
||||||
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_self_seq_add(
|
|
||||||
llama_context * ctx,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
llama_pos delta) {
|
|
||||||
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
void llama_kv_cache_seq_div(
|
|
||||||
llama_context * ctx,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
int d) {
|
|
||||||
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_self_seq_div(
|
|
||||||
llama_context * ctx,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
int d) {
|
|
||||||
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
|
||||||
return llama_kv_self_seq_pos_max(ctx, seq_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
|
||||||
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
void llama_kv_cache_defrag(llama_context * ctx) {
|
|
||||||
return llama_kv_self_defrag(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_self_defrag(llama_context * ctx) {
|
|
||||||
return llama_kv_cache_defrag(&ctx->kv_self);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
bool llama_kv_cache_can_shift(const llama_context * ctx) {
|
|
||||||
return llama_kv_self_can_shift(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
|
||||||
return llama_kv_cache_can_shift(&ctx->kv_self);
|
|
||||||
}
|
|
||||||
|
|
||||||
// deprecated
|
|
||||||
void llama_kv_cache_update(llama_context * ctx) {
|
|
||||||
llama_kv_self_update(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_self_update(llama_context * ctx) {
|
|
||||||
const bool need_reserve = ctx->kv_self_update();
|
|
||||||
|
|
||||||
// reserve a worst case graph again
|
|
||||||
if (need_reserve) {
|
|
||||||
// TODO: extract to a function
|
|
||||||
const auto & cparams = ctx->cparams;
|
|
||||||
const auto & model = ctx->model;
|
|
||||||
|
|
||||||
// build worst-case graph
|
|
||||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
|
||||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
||||||
|
|
||||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
|
||||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
||||||
|
|
||||||
ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
|
|
||||||
|
|
||||||
// initialize scheduler with the worst-case graph
|
|
||||||
ggml_backend_sched_reset(ctx->sched.get());
|
|
||||||
if (!ggml_backend_sched_reserve(ctx->sched.get(), gf)) {
|
|
||||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///
|
///
|
||||||
|
|
||||||
int32_t llama_encode(
|
int32_t llama_encode(
|
||||||
|
Loading…
Reference in New Issue
Block a user