llama : factor out tensor offloading outside the build call (wip)

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-28 21:22:31 +03:00
parent 5946d98fc8
commit 38aca9e1ab
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

200
llama.cpp
View File

@ -3008,10 +3008,10 @@ static void llm_load_tensors(
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
const int max_backend_supported_layers = hparams.n_layer + 3; const int max_backend_supported_layers = hparams.n_layer + 3;
const int max_offloadable_layers = hparams.n_layer + 3; const int max_offloadable_layers = hparams.n_layer + 3;
#elif defined(GGML_USE_CLBLAST) #elif defined(GGML_USE_CLBLAST)
const int max_backend_supported_layers = hparams.n_layer + 1; const int max_backend_supported_layers = hparams.n_layer + 1;
const int max_offloadable_layers = hparams.n_layer + 1; const int max_offloadable_layers = hparams.n_layer + 1;
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
@ -3116,8 +3116,6 @@ static struct ggml_cgraph * llm_build_llama(
const float freq_scale = cparams.rope_freq_scale; const float freq_scale = cparams.rope_freq_scale;
const float norm_rms_eps = hparams.f_norm_rms_eps; const float norm_rms_eps = hparams.f_norm_rms_eps;
const int n_gpu_layers = model.n_gpu_layers;
const int32_t n_tokens = batch.n_tokens; const int32_t n_tokens = batch.n_tokens;
const int32_t n_kv = worst_case ? n_ctx : kv_self.n; const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
@ -3155,45 +3153,21 @@ static struct ggml_cgraph * llm_build_llama(
} }
ggml_set_name(inpL, "inp_embd"); ggml_set_name(inpL, "inp_embd");
const int i_gpu_start = n_layer - n_gpu_layers;
(void) i_gpu_start;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 1) {
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 2) {
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
// KQ_scale // KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_set_name(KQ_scale, "KQ_scale"); ggml_set_name(KQ_scale, "KQ_scale");
// KQ_mask (mask for 1 head, it will be broadcasted to all heads) // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
offload_func_kq(KQ_mask);
ggml_set_name(KQ_mask, "KQ_mask"); ggml_set_name(KQ_mask, "KQ_mask");
// KQ_pos - contains the positions // KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_set_name(KQ_pos, "KQ_pos"); ggml_set_name(KQ_pos, "KQ_pos");
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_set_name(K_shift, "K_shift"); ggml_set_name(K_shift, "K_shift");
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -3205,33 +3179,21 @@ static struct ggml_cgraph * llm_build_llama(
ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
K_shift, n_embd_head, 0, 0, freq_base, freq_scale); K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(tmp); ggml_set_name(tmp, "K_shifted");
ggml_build_forward_expand(gf, tmp); ggml_build_forward_expand(gf, tmp);
} }
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il);
offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS
if (il >= i_gpu_start) {
offload_func = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
struct ggml_tensor * inpSA = inpL; struct ggml_tensor * inpSA = inpL;
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
offload_func(cur);
ggml_set_name(cur, "rms_norm_0"); ggml_set_name(cur, "rms_norm_0");
// cur = cur*attn_norm(broadcasted) // cur = cur*attn_norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
offload_func(cur);
ggml_set_name(cur, "attention_norm_0"); ggml_set_name(cur, "attention_norm_0");
} }
@ -3239,19 +3201,15 @@ static struct ggml_cgraph * llm_build_llama(
{ {
// compute Q and K and RoPE them // compute Q and K and RoPE them
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
offload_func_kq(tmpk);
ggml_set_name(tmpk, "tmpk"); ggml_set_name(tmpk, "tmpk");
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq"); ggml_set_name(tmpq, "tmpq");
struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur"); ggml_set_name(Kcur, "Kcur");
struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur"); ggml_set_name(Qcur, "Qcur");
// store key and value to memory // store key and value to memory
@ -3259,21 +3217,17 @@ static struct ggml_cgraph * llm_build_llama(
// compute the transposed [n_tokens, n_embd] V matrix // compute the transposed [n_tokens, n_embd] V matrix
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
offload_func_v(tmpv);
ggml_set_name(tmpv, "tmpv"); ggml_set_name(tmpv, "tmpv");
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
offload_func_v(Vcur);
ggml_set_name(Vcur, "Vcur"); ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
offload_func_kq(k);
ggml_set_name(k, "k"); ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v), ( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
offload_func_v(v);
ggml_set_name(v, "v"); ggml_set_name(v, "v");
// important: storing RoPE-ed version of K in the KV cache! // important: storing RoPE-ed version of K in the KV cache!
@ -3282,7 +3236,6 @@ static struct ggml_cgraph * llm_build_llama(
} }
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
offload_func_kq(Q);
ggml_set_name(Q, "Q"); ggml_set_name(Q, "Q");
struct ggml_tensor * K = struct ggml_tensor * K =
@ -3291,28 +3244,23 @@ static struct ggml_cgraph * llm_build_llama(
ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K);
ggml_set_name(K, "K"); ggml_set_name(K, "K");
// K * Q // K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ"); ggml_set_name(KQ, "KQ");
// KQ_scaled = KQ / sqrt(n_embd_head) // KQ_scaled = KQ / sqrt(n_embd_head)
// KQ_scaled shape [n_kv, n_tokens, n_head, 1] // KQ_scaled shape [n_kv, n_tokens, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled"); ggml_set_name(KQ_scaled, "KQ_scaled");
// KQ_masked = mask_past(KQ_scaled) // KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked"); ggml_set_name(KQ_masked, "KQ_masked");
// KQ = soft_max(KQ_masked) // KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max"); ggml_set_name(KQ_soft_max, "KQ_soft_max");
// split cached V into n_head heads // split cached V into n_head heads
@ -3322,12 +3270,10 @@ static struct ggml_cgraph * llm_build_llama(
ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
offload_func_v(V);
ggml_set_name(V, "V"); ggml_set_name(V, "V");
#if 1 #if 1
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
offload_func_v(KQV);
ggml_set_name(KQV, "KQV"); ggml_set_name(KQV, "KQV");
#else #else
// make V contiguous in memory to speed up the matmul, however we waste time on the copy // make V contiguous in memory to speed up the matmul, however we waste time on the copy
@ -3339,24 +3285,20 @@ static struct ggml_cgraph * llm_build_llama(
// KQV_merged = KQV.permute(0, 2, 1, 3) // KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged"); ggml_set_name(KQV_merged, "KQV_merged");
// cur = KQV_merged.contiguous().view(n_embd, n_tokens) // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous"); ggml_set_name(cur, "KQV_merged_contiguous");
// projection (no bias) // projection (no bias)
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].wo, model.layers[il].wo,
cur); cur);
offload_func(cur);
ggml_set_name(cur, "result_wo"); ggml_set_name(cur, "result_wo");
} }
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
offload_func(inpFF);
ggml_set_name(inpFF, "inpFF"); ggml_set_name(inpFF, "inpFF");
// feed-forward network // feed-forward network
@ -3364,45 +3306,37 @@ static struct ggml_cgraph * llm_build_llama(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps); cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
offload_func(cur);
ggml_set_name(cur, "rms_norm_1"); ggml_set_name(cur, "rms_norm_1");
// cur = cur*ffn_norm(broadcasted) // cur = cur*ffn_norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
offload_func(cur);
ggml_set_name(cur, "ffn_norm"); ggml_set_name(cur, "ffn_norm");
} }
struct ggml_tensor * tmp = ggml_mul_mat(ctx0, struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
model.layers[il].w3, model.layers[il].w3,
cur); cur);
offload_func(tmp);
ggml_set_name(tmp, "result_w3"); ggml_set_name(tmp, "result_w3");
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].w1, model.layers[il].w1,
cur); cur);
offload_func(cur);
ggml_set_name(cur, "result_w1"); ggml_set_name(cur, "result_w1");
// SILU activation // SILU activation
cur = ggml_silu(ctx0, cur); cur = ggml_silu(ctx0, cur);
offload_func(cur);
ggml_set_name(cur, "silu"); ggml_set_name(cur, "silu");
cur = ggml_mul(ctx0, cur, tmp); cur = ggml_mul(ctx0, cur, tmp);
offload_func(cur);
ggml_set_name(cur, "silu_x_result_w3"); ggml_set_name(cur, "silu_x_result_w3");
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].w2, model.layers[il].w2,
cur); cur);
offload_func(cur);
ggml_set_name(cur, "result_w2"); ggml_set_name(cur, "result_w2");
} }
cur = ggml_add(ctx0, cur, inpFF); cur = ggml_add(ctx0, cur, inpFF);
offload_func(cur);
ggml_set_name(cur, "inpFF_+_result_w2"); ggml_set_name(cur, "inpFF_+_result_w2");
// input for next layer // input for next layer
@ -3414,12 +3348,10 @@ static struct ggml_cgraph * llm_build_llama(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, cur, norm_rms_eps); cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_2"); ggml_set_name(cur, "rms_norm_2");
// cur = cur*norm(broadcasted) // cur = cur*norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.output_norm); cur = ggml_mul(ctx0, cur, model.output_norm);
// offload_func_nr(cur); // TODO CPU + GPU mirrored backend
ggml_set_name(cur, "result_norm"); ggml_set_name(cur, "result_norm");
} }
@ -3884,7 +3816,6 @@ static struct ggml_cgraph * llm_build_refact(
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); ggml_format_name(inpL, "layer_inp_%d", il);
offload_func_t offload_func = llama_nop; offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
@ -5641,7 +5572,7 @@ static struct ggml_cgraph * llama_build_graph(
GGML_ASSERT(false); GGML_ASSERT(false);
} }
// set input data to the graph // allocate memory and set the values for the input tensors of the graph
// inp_tokens // inp_tokens
if (batch.token) { if (batch.token) {
@ -5655,7 +5586,10 @@ static struct ggml_cgraph * llama_build_graph(
memcpy(cur->data, batch.token, n_tokens*ggml_element_size(cur)); memcpy(cur->data, batch.token, n_tokens*ggml_element_size(cur));
} }
} else { // inp_embd }
// inp_embd
if (batch.embd) {
struct ggml_tensor * cur = ggml_graph_get_tensor(result, "inp_embd"); struct ggml_tensor * cur = ggml_graph_get_tensor(result, "inp_embd");
GGML_ASSERT(cur != nullptr); GGML_ASSERT(cur != nullptr);
@ -5775,6 +5709,124 @@ static struct ggml_cgraph * llama_build_graph(
} }
} while (0); } while (0);
// offload layers
{
const int n_layer = model.hparams.n_layer;
const int n_gpu_layers = model.n_gpu_layers;
const int i_gpu_start = n_layer - n_gpu_layers;
GGML_UNUSED(i_gpu_start);
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 1) {
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 2) {
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
}
offload_func = ggml_cuda_assign_buffers_no_alloc;
#endif // GGML_USE_CUBLAS
static const std::unordered_map<std::string, offload_func_t> k_offload_func = {
{ "KQ_mask", offload_func_kq },
{ "KQ_pos", offload_func_kq },
{ "K_shift", offload_func_kq },
{ "K_shifted", offload_func_kq },
{ "rms_norm_0", offload_func },
{ "attention_norm_0", offload_func },
{ "tmpk", offload_func_kq },
{ "tmpq", offload_func_kq },
{ "tmpv", offload_func_v },
{ "Kcur", offload_func_kq },
{ "Qcur", offload_func_kq },
{ "Vcur", offload_func_v },
{ "k", offload_func_kq },
{ "v", offload_func_v },
{ "Q", offload_func_kq },
{ "K", offload_func_kq },
{ "KQ", offload_func_kq },
{ "KQ_scaled", offload_func_kq },
{ "KQ_scaled_alibi", offload_func_kq },
{ "KQ_masked", offload_func_kq },
{ "KQ_soft_max", offload_func_v },
{ "V", offload_func_v },
{ "KQV", offload_func_v },
{ "KQV_merged", offload_func_v },
{ "KQV_merged_contiguous", offload_func_v },
{ "result_wo", offload_func },
{ "inpFF", offload_func },
{ "rms_norm_1", offload_func },
{ "ffn_norm", offload_func },
{ "result_w3", offload_func },
{ "result_w2", offload_func },
{ "result_w1", offload_func },
{ "silu", offload_func },
{ "silu_x_result_w3", offload_func },
{ "inpFF_+_result_w2", offload_func },
{ "rms_norm_2", offload_func_nr },
//{ "result_norm", offload_func_nr }, // TODO CPU + GPU mirrored backend
//{ "result_output", offload_func },
};
static const std::unordered_map<offload_func_t, std::string> k_offload_func_name = {
{ llama_nop, "CPU" },
#ifdef GGML_USE_CUBLAS
{ ggml_cuda_assign_buffers_no_alloc, "GPU (CUDA)" },
#endif
};
std::unordered_map<std::string, int> ofn;
for (int i = 0; i < result->n_nodes; ++i) {
struct ggml_tensor * cur = result->nodes[i];
const std::string name = cur->name;
if (k_offload_func.find(name) == k_offload_func.end()) {
if (worst_case && cur->view_src == nullptr) {
LLAMA_LOG_WARN("%s: %32s: not offloaded (ref: %s)\n", __func__,
name.c_str(), "https://github.com/ggerganov/llama.cpp/pull/3837");
}
continue;
}
offload_func_t f = k_offload_func.at(name);
if (f == offload_func) {
if (ofn[name]++ < i_gpu_start) {
f = llama_nop;
}
}
f(cur);
if (worst_case && cur->view_src == nullptr) {
LLAMA_LOG_INFO("%s: %32s: %s\n", __func__, name.c_str(), k_offload_func_name.at(f).c_str());
}
}
}
return result; return result;
} }