context : prepare for abstraction

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-17 21:11:03 +02:00
parent c6fa715709
commit a47d389c27
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 323 additions and 309 deletions

View File

@ -32,6 +32,309 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
return relative_bucket;
}
llama_context::llama_context(const llama_model & model, const llama_context_params & params, std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst) :
model(model),
t_start_us(model.t_start_us),
t_load_us (model.t_load_us) {
const auto & hparams = model.hparams;
cparams.n_seq_max = std::max(1u, params.n_seq_max);
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.defrag_thold = params.defrag_thold;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
cparams.n_ctx = GGML_PAD(cparams.n_ctx, get_ctx_padding(cparams));
// with causal attention, the batch size is limited by the context size
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
cparams.n_batch = GGML_KQ_MASK_PAD;
}
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
hparams.n_ctx_train;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
auto rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
rope_scaling_type = hparams.rope_scaling_type_train;
}
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
}
if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
} else {
cparams.pooling_type = hparams.pooling_type;
}
}
if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
cparams.causal_attn = hparams.causal_attn;
} else {
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
}
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
if (n_ctx_per_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
}
if (n_ctx_per_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
}
logits_all = params.logits_all;
// build worst-case graph for encoder if a model contains encoder
is_encoding = llama_model_has_encoder(&model); // TODO: model.has_encoder()
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
// Mamba only needs a constant number of KV cache cells per sequence
if (llama_model_is_recurrent(&model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
if (!hparams.vocab_only) {
// GPU backends
for (auto * dev : model.devices) {
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
throw std::runtime_error("failed to initialize backend");
}
backends.emplace_back(backend);
}
// add ACCEL backends (such as BLAS)
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
throw std::runtime_error("failed to initialize backend");
}
backends.emplace_back(backend);
}
}
// add CPU backend
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (backend_cpu == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
throw std::runtime_error("failed to initialize CPU backend");
}
backends.emplace_back(backend_cpu);
// create a list of the set_n_threads functions in the backends
for (auto & backend : backends) {
ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
if (reg) {
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (ggml_backend_set_n_threads_fn) {
set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
}
}
}
llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);
if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
throw std::runtime_error("failed to initialize self-attention cache");
}
{
const size_t memory_size_k = kv_self.size_k_bytes();
const size_t memory_size_v = kv_self.size_v_bytes();
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
// graph outputs buffer
{
// resized during inference when a batch uses more outputs
if (llama_output_reserve(*this, params.n_seq_max) < params.n_seq_max) {
LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
throw std::runtime_error("failed to reserve initial output buffer");
}
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
ggml_backend_buffer_name (buf_output.get()),
ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
}
// scheduler and compute buffers
{
// buffer types used for the compute buffer of each backend
std::vector<ggml_backend_buffer_type_t> backend_buft;
std::vector<ggml_backend_t> backend_ptrs;
for (auto & backend : backends) {
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
// use the host buffer of the first device CPU for faster transfer of the intermediate state
auto * dev = model.devices[0];
auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
if (host_buft) {
buft = host_buft;
}
}
backend_buft.push_back(buft);
backend_ptrs.push_back(backend.get());
}
const size_t max_nodes = model.max_nodes();
// buffer used to store the computation graph and the tensor meta data
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
// TODO: move these checks to ggml_backend_sched
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
bool pipeline_parallel =
model.n_devices() > 1 &&
model.params.n_gpu_layers > (int) model.hparams.n_layer &&
model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
params.offload_kqv;
// pipeline parallelism requires support for async compute and events in all devices
if (pipeline_parallel) {
for (auto & backend : backends) {
auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
// ignore CPU backend
continue;
}
auto * dev = ggml_backend_get_device(backend.get());
ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props);
if (!props.caps.async || !props.caps.events) {
// device does not support async compute or events
pipeline_parallel = false;
break;
}
}
}
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
if (pipeline_parallel) {
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
}
// initialize scheduler with the worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf_pp = fn_build_graph_worst(*this, ubatch_pp);
// reserve pp graph first so that buffers are only allocated once
ggml_backend_sched_reserve(sched.get(), gf_pp);
int n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
// reserve with tg graph to get the number of splits and nodes
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf_tg = fn_build_graph_worst(*this, ubatch_tg);
ggml_backend_sched_reserve(sched.get(), gf_tg);
int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
gf_pp = fn_build_graph_worst(*this, ubatch_pp);
if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
throw std::runtime_error("failed to allocate compute buffers");
}
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
ggml_backend_t backend = backend_ptrs[i];
ggml_backend_buffer_type_t buft = backend_buft[i];
size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
if (size > 1) {
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
ggml_backend_buft_name(buft),
size / 1024.0 / 1024.0);
}
}
if (n_nodes_pp == n_nodes_tg) {
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
} else {
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
}
if (n_splits_pp == n_splits_tg) {
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
} else {
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
}
}
}
}
struct llama_batch_manager : public llama_batch_manager_i {
llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
const auto & hparams = lctx.model.hparams;
@ -81,7 +384,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_self_update(&lctx);
lctx.kv_self_update();
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
@ -106,6 +409,8 @@ struct llama_batch_manager : public llama_batch_manager_i {
}
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
return true;
}

View File

@ -30,11 +30,14 @@ struct llama_batch_manager_i {
virtual void finalize() = 0;
};
// TODO: make implementation details private
// TODO: become abstract base class, split the current implementation into different child classes
struct llama_context {
llama_context(const llama_model & model)
: model(model)
, t_start_us(model.t_start_us)
, t_load_us (model.t_load_us) {}
// TODO: store the worst-case graph build function and reuse it later
llama_context(
const llama_model & model,
const llama_context_params & params,
std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst);
const struct llama_model & model;

View File

@ -7893,8 +7893,6 @@ static int llama_decode_impl(
lctx.need_reserve = false;
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_backend_sched_reset(lctx.sched.get());
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
@ -8564,309 +8562,17 @@ struct llama_context * llama_init_from_model(
return nullptr;
}
llama_context * ctx = new llama_context(*model);
llama_context * ctx = nullptr;
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
cparams.n_seq_max = std::max(1u, params.n_seq_max);
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.defrag_thold = params.defrag_thold;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
// with causal attention, the batch size is limited by the context size
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
cparams.n_batch = GGML_KQ_MASK_PAD;
}
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
hparams.n_ctx_train;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
auto rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
rope_scaling_type = hparams.rope_scaling_type_train;
}
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
}
if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
} else {
cparams.pooling_type = hparams.pooling_type;
}
}
if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
cparams.causal_attn = hparams.causal_attn;
} else {
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
}
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
if (n_ctx_per_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
}
if (n_ctx_per_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
}
ctx->logits_all = params.logits_all;
// build worst-case graph for encoder if a model contains encoder
ctx->is_encoding = llama_model_has_encoder(model);
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
// Mamba only needs a constant number of KV cache cells per sequence
if (llama_model_is_recurrent(model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
if (!hparams.vocab_only) {
// GPU backends
for (auto * dev : model->devices) {
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
llama_free(ctx);
return nullptr;
}
ctx->backends.emplace_back(backend);
}
// add ACCEL backends (such as BLAS)
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
llama_free(ctx);
return nullptr;
}
ctx->backends.emplace_back(backend);
}
}
// add CPU backend
ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (ctx->backend_cpu == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
llama_free(ctx);
return nullptr;
}
ctx->backends.emplace_back(ctx->backend_cpu);
// create a list of the set_n_threads functions in the backends
for (auto & backend : ctx->backends) {
ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
if (reg) {
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (ggml_backend_set_n_threads_fn) {
ctx->set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
}
}
}
llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
if (!ctx->kv_self.init(ctx->model, ctx->cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
}
{
const size_t memory_size_k = ctx->kv_self.size_k_bytes();
const size_t memory_size_v = ctx->kv_self.size_v_bytes();
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
// graph outputs buffer
{
// resized during inference when a batch uses more outputs
if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) {
LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
llama_free(ctx);
return nullptr;
}
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
ggml_backend_buffer_name(ctx->buf_output.get()),
ggml_backend_buffer_get_size(ctx->buf_output.get()) / 1024.0 / 1024.0);
}
// scheduler and compute buffers
{
// buffer types used for the compute buffer of each backend
std::vector<ggml_backend_buffer_type_t> backend_buft;
std::vector<ggml_backend_t> backend_ptrs;
for (auto & backend : ctx->backends) {
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
// use the host buffer of the first device CPU for faster transfer of the intermediate state
auto * dev = model->devices[0];
auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
if (host_buft) {
buft = host_buft;
}
}
backend_buft.push_back(buft);
backend_ptrs.push_back(backend.get());
}
const size_t max_nodes = model->max_nodes();
// buffer used to store the computation graph and the tensor meta data
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
// TODO: move these checks to ggml_backend_sched
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
bool pipeline_parallel =
model->n_devices() > 1 &&
model->params.n_gpu_layers > (int)model->hparams.n_layer &&
model->params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
params.offload_kqv;
// pipeline parallelism requires support for async compute and events in all devices
if (pipeline_parallel) {
for (auto & backend : ctx->backends) {
auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
// ignore CPU backend
continue;
}
auto * dev = ggml_backend_get_device(backend.get());
ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props);
if (!props.caps.async || !props.caps.events) {
// device does not support async compute or events
pipeline_parallel = false;
break;
}
}
}
ctx->sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
if (pipeline_parallel) {
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched.get()));
}
// initialize scheduler with the worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
// reserve pp graph first so that buffers are only allocated once
ggml_backend_sched_reserve(ctx->sched.get(), gf_pp);
int n_splits_pp = ggml_backend_sched_get_n_splits(ctx->sched.get());
int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
// reserve with tg graph to get the number of splits and nodes
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf_tg = llama_build_graph(*ctx, ubatch_tg, true);
ggml_backend_sched_reserve(ctx->sched.get(), gf_tg);
int n_splits_tg = ggml_backend_sched_get_n_splits(ctx->sched.get());
int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
if (!ggml_backend_sched_reserve(ctx->sched.get(), gf_pp)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
llama_free(ctx);
return nullptr;
}
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
ggml_backend_t backend = backend_ptrs[i];
ggml_backend_buffer_type_t buft = backend_buft[i];
size_t size = ggml_backend_sched_get_buffer_size(ctx->sched.get(), backend);
if (size > 1) {
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
ggml_backend_buft_name(buft),
size / 1024.0 / 1024.0);
}
}
if (n_nodes_pp == n_nodes_tg) {
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
} else {
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
}
if (n_splits_pp == n_splits_tg) {
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
} else {
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
}
}
try {
// TODO: add logic which llama_context implementation to construct
ctx = new llama_context(*model, params,
[](llama_context & lctx, const llama_ubatch & ubatch) {
return llama_build_graph(lctx, ubatch, true);
});
} catch (const std::exception & e) {
LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
return nullptr;
}
return ctx;