mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-06 16:40:34 +01:00
context : initial need_reserve logic
ggml-ci
This commit is contained in:
parent
2f6d767fc5
commit
1f7dc12222
@ -576,9 +576,7 @@ ggml_tensor * llama_context::build_lora_mm_id(
|
||||
return res;
|
||||
}
|
||||
|
||||
bool llama_context::kv_self_update() {
|
||||
bool need_reserve = false;
|
||||
|
||||
void llama_context::kv_self_update() {
|
||||
auto & kv = kv_self;
|
||||
|
||||
if (kv.has_shift) {
|
||||
@ -655,12 +653,14 @@ bool llama_context::kv_self_update() {
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
need_reserve = true;
|
||||
|
||||
kv.do_defrag = false;
|
||||
}
|
||||
|
||||
return need_reserve;
|
||||
need_reserve = true;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_self_update(llama_context * ctx) {
|
||||
ctx->kv_self_update();
|
||||
}
|
||||
|
||||
void llama_context::build_attn_inp(
|
||||
@ -1824,6 +1824,165 @@ int32_t llama_apply_adapter_cvec(
|
||||
return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
//
|
||||
|
||||
struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
|
||||
return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
|
||||
}
|
||||
|
||||
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
|
||||
llama_kv_cache_view_update(view, ctx->kv_self);
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache
|
||||
//
|
||||
|
||||
// deprecated
|
||||
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
|
||||
return llama_kv_self_n_tokens(ctx);
|
||||
}
|
||||
|
||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
return llama_kv_cache_n_tokens(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
|
||||
return llama_kv_self_used_cells(ctx);
|
||||
}
|
||||
|
||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
return llama_kv_cache_used_cells(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_clear(llama_context * ctx) {
|
||||
llama_kv_self_clear(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_self_clear(llama_context * ctx) {
|
||||
llama_kv_cache_clear(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_cache_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
bool llama_kv_self_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_keep(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id) {
|
||||
return llama_kv_self_seq_keep(ctx, seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_self_seq_pos_max(ctx, seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_defrag(llama_context * ctx) {
|
||||
return llama_kv_self_defrag(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_self_defrag(llama_context * ctx) {
|
||||
return llama_kv_cache_defrag(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_cache_can_shift(const llama_context * ctx) {
|
||||
return llama_kv_self_can_shift(ctx);
|
||||
}
|
||||
|
||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||
return llama_kv_cache_can_shift(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_update(llama_context * ctx) {
|
||||
llama_kv_self_update(ctx);
|
||||
}
|
||||
|
||||
// llama state API
|
||||
|
||||
|
@ -62,6 +62,7 @@ struct llama_context {
|
||||
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
||||
|
||||
bool logits_all = false;
|
||||
bool need_reserve = false;
|
||||
|
||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
@ -87,6 +88,7 @@ struct llama_context {
|
||||
// max token position across all sequences in the current context
|
||||
llama_pos pos_max() const;
|
||||
|
||||
// certain implementations could require a padding for the context size
|
||||
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
|
||||
|
||||
void reset();
|
||||
@ -140,7 +142,7 @@ struct llama_context {
|
||||
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
||||
|
||||
// return true if need to reserve new worst-case graph
|
||||
bool kv_self_update();
|
||||
void kv_self_update();
|
||||
|
||||
void build_attn_inp(
|
||||
ggml_context * ctx0,
|
||||
|
337
src/llama.cpp
337
src/llama.cpp
@ -28,57 +28,6 @@
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
||||
static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
|
||||
// loading time will be recalculated after the first eval, so
|
||||
// we take page faults deferred by mmap() into consideration
|
||||
model.t_load_us = 0;
|
||||
time_meas tm(model.t_load_us);
|
||||
|
||||
model.t_start_us = tm.t_start_us;
|
||||
|
||||
try {
|
||||
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
|
||||
|
||||
ml.print_info();
|
||||
|
||||
model.hparams.vocab_only = params.vocab_only;
|
||||
|
||||
try {
|
||||
model.load_arch(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
|
||||
}
|
||||
try {
|
||||
model.load_hparams(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
|
||||
}
|
||||
try {
|
||||
model.load_vocab(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
model.load_stats(ml);
|
||||
model.print_info();
|
||||
|
||||
if (params.vocab_only) {
|
||||
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!model.load_tensors(ml)) {
|
||||
return -2;
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// llm_build
|
||||
//
|
||||
@ -7951,6 +7900,30 @@ static int llama_decode_impl(
|
||||
}
|
||||
}
|
||||
|
||||
// reserve a worst case graph if needed
|
||||
// TODO: extract to a function
|
||||
if (lctx.need_reserve) {
|
||||
const auto & cparams = lctx.cparams;
|
||||
const auto & model = lctx.model;
|
||||
|
||||
// build worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
}
|
||||
|
||||
lctx.need_reserve = false;
|
||||
}
|
||||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
@ -8206,6 +8179,31 @@ static int llama_encode_impl(
|
||||
|
||||
lctx.prepare_decode(ubatch);
|
||||
|
||||
// reserve a worst case graph if needed
|
||||
// TODO: extract to a function
|
||||
if (lctx.need_reserve) {
|
||||
// TODO: extract to a function
|
||||
const auto & cparams = lctx.cparams;
|
||||
const auto & model = lctx.model;
|
||||
|
||||
// build worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
}
|
||||
|
||||
lctx.need_reserve = false;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||
|
||||
@ -8419,6 +8417,57 @@ int64_t llama_time_us(void) {
|
||||
return ggml_time_us();
|
||||
}
|
||||
|
||||
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
||||
static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
|
||||
// loading time will be recalculated after the first eval, so
|
||||
// we take page faults deferred by mmap() into consideration
|
||||
model.t_load_us = 0;
|
||||
time_meas tm(model.t_load_us);
|
||||
|
||||
model.t_start_us = tm.t_start_us;
|
||||
|
||||
try {
|
||||
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
|
||||
|
||||
ml.print_info();
|
||||
|
||||
model.hparams.vocab_only = params.vocab_only;
|
||||
|
||||
try {
|
||||
model.load_arch(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
|
||||
}
|
||||
try {
|
||||
model.load_hparams(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
|
||||
}
|
||||
try {
|
||||
model.load_vocab(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
model.load_stats(ml);
|
||||
model.print_info();
|
||||
|
||||
if (params.vocab_only) {
|
||||
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!model.load_tensors(ml)) {
|
||||
return -2;
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static struct llama_model * llama_model_load_from_file_impl(
|
||||
const std::string & path_model,
|
||||
std::vector<std::string> & splits,
|
||||
@ -8879,192 +8928,6 @@ struct llama_context * llama_new_context_with_model(
|
||||
return llama_init_from_model(model, params);
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
//
|
||||
|
||||
struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
|
||||
return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
|
||||
}
|
||||
|
||||
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
|
||||
llama_kv_cache_view_update(view, ctx->kv_self);
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache
|
||||
//
|
||||
|
||||
// deprecated
|
||||
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
|
||||
return llama_kv_self_n_tokens(ctx);
|
||||
}
|
||||
|
||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
return llama_kv_cache_n_tokens(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
|
||||
return llama_kv_self_used_cells(ctx);
|
||||
}
|
||||
|
||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
return llama_kv_cache_used_cells(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_clear(llama_context * ctx) {
|
||||
llama_kv_self_clear(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_self_clear(llama_context * ctx) {
|
||||
llama_kv_cache_clear(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_cache_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
bool llama_kv_self_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_keep(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id) {
|
||||
return llama_kv_self_seq_keep(ctx, seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_self_seq_pos_max(ctx, seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_defrag(llama_context * ctx) {
|
||||
return llama_kv_self_defrag(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_self_defrag(llama_context * ctx) {
|
||||
return llama_kv_cache_defrag(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_cache_can_shift(const llama_context * ctx) {
|
||||
return llama_kv_self_can_shift(ctx);
|
||||
}
|
||||
|
||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||
return llama_kv_cache_can_shift(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_update(llama_context * ctx) {
|
||||
llama_kv_self_update(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_self_update(llama_context * ctx) {
|
||||
const bool need_reserve = ctx->kv_self_update();
|
||||
|
||||
// reserve a worst case graph again
|
||||
if (need_reserve) {
|
||||
// TODO: extract to a function
|
||||
const auto & cparams = ctx->cparams;
|
||||
const auto & model = ctx->model;
|
||||
|
||||
// build worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(ctx->sched.get());
|
||||
if (!ggml_backend_sched_reserve(ctx->sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
|
||||
int32_t llama_encode(
|
||||
|
Loading…
Reference in New Issue
Block a user