From 38aca9e1abac5564a5da71ee35a136ce76d2e29a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sat, 28 Oct 2023 21:22:31 +0300
Subject: [PATCH] llama : factor out tensor offloading outside the build call
 (wip)

ggml-ci
---
 llama.cpp | 200 ++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 126 insertions(+), 74 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 43c629358..22129ee43 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3008,10 +3008,10 @@ static void llm_load_tensors(
 
 #ifdef GGML_USE_CUBLAS
         const int max_backend_supported_layers = hparams.n_layer + 3;
-        const int max_offloadable_layers = hparams.n_layer + 3;
+        const int max_offloadable_layers       = hparams.n_layer + 3;
 #elif defined(GGML_USE_CLBLAST)
         const int max_backend_supported_layers = hparams.n_layer + 1;
-        const int max_offloadable_layers = hparams.n_layer + 1;
+        const int max_offloadable_layers       = hparams.n_layer + 1;
 #endif // GGML_USE_CUBLAS
 
         LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
@@ -3116,8 +3116,6 @@ static struct ggml_cgraph * llm_build_llama(
     const float freq_scale   = cparams.rope_freq_scale;
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int n_gpu_layers = model.n_gpu_layers;
-
     const int32_t n_tokens = batch.n_tokens;
     const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
     const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
@@ -3155,45 +3153,21 @@ static struct ggml_cgraph * llm_build_llama(
     }
     ggml_set_name(inpL, "inp_embd");
 
-    const int i_gpu_start = n_layer - n_gpu_layers;
-    (void) i_gpu_start;
-
-    // offload functions set the tensor output backend to GPU
-    // tensors are GPU-accelerated if any input or the output has been offloaded
-    offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
-    offload_func_t offload_func_kq = llama_nop;
-    offload_func_t offload_func_v  = llama_nop;
-
-#ifdef GGML_USE_CUBLAS
-    if (n_gpu_layers > n_layer) {
-        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
-    }
-    if (n_gpu_layers > n_layer + 1) {
-        offload_func_v  = ggml_cuda_assign_buffers_no_alloc;
-    }
-    if (n_gpu_layers > n_layer + 2) {
-        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
-    }
-#endif // GGML_USE_CUBLAS
-
     // KQ_scale
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_set_name(KQ_scale, "KQ_scale");
 
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
     struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
-    offload_func_kq(KQ_mask);
     ggml_set_name(KQ_mask, "KQ_mask");
 
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-    offload_func_kq(KQ_pos);
     ggml_set_name(KQ_pos, "KQ_pos");
 
     // shift the entire K-cache if needed
     if (do_rope_shift) {
         struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        offload_func_kq(K_shift);
         ggml_set_name(K_shift, "K_shift");
 
         for (int il = 0; il < n_layer; ++il) {
@@ -3205,33 +3179,21 @@ static struct ggml_cgraph * llm_build_llama(
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
                     K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
-            offload_func_kq(tmp);
+            ggml_set_name(tmp, "K_shifted");
             ggml_build_forward_expand(gf, tmp);
         }
     }
 
     for (int il = 0; il < n_layer; ++il) {
-        ggml_format_name(inpL, "layer_inp_%d", il);
-
-        offload_func_t offload_func = llama_nop;
-
-#ifdef GGML_USE_CUBLAS
-        if (il >= i_gpu_start) {
-            offload_func = ggml_cuda_assign_buffers_no_alloc;
-        }
-#endif // GGML_USE_CUBLAS
-
         struct ggml_tensor * inpSA = inpL;
 
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
-            offload_func(cur);
             ggml_set_name(cur, "rms_norm_0");
 
             // cur = cur*attn_norm(broadcasted)
             cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
-            offload_func(cur);
             ggml_set_name(cur, "attention_norm_0");
         }
 
@@ -3239,19 +3201,15 @@ static struct ggml_cgraph * llm_build_llama(
         // self-attention
         {
             // compute Q and K and RoPE them
             struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-            offload_func_kq(tmpk);
             ggml_set_name(tmpk, "tmpk");
 
             struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-            offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");
 
             struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
-            offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
             struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
-            offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");
 
             // store key and value to memory
@@ -3259,21 +3217,17 @@ static struct ggml_cgraph * llm_build_llama(
                 // compute the transposed [n_tokens, n_embd] V matrix
 
                 struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-                offload_func_v(tmpv);
                 ggml_set_name(tmpv, "tmpv");
 
                 struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
-                offload_func_v(Vcur);
                 ggml_set_name(Vcur, "Vcur");
 
                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
-                offload_func_kq(k);
                 ggml_set_name(k, "k");
 
                 struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
                         (   n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
-                offload_func_v(v);
                 ggml_set_name(v, "v");
 
                 // important: storing RoPE-ed version of K in the KV cache!
@@ -3282,7 +3236,6 @@ static struct ggml_cgraph * llm_build_llama(
             }
 
             struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-            offload_func_kq(Q);
             ggml_set_name(Q, "Q");
 
             struct ggml_tensor * K =
@@ -3291,28 +3244,23 @@ static struct ggml_cgraph * llm_build_llama(
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
-            offload_func_kq(K);
             ggml_set_name(K, "K");
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            offload_func_kq(KQ);
             ggml_set_name(KQ, "KQ");
 
             // KQ_scaled = KQ / sqrt(n_embd_head)
             // KQ_scaled shape [n_kv, n_tokens, n_head, 1]
             struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
-            offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
-            offload_func_kq(KQ_masked);
             ggml_set_name(KQ_masked, "KQ_masked");
 
             // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
-            offload_func_v(KQ_soft_max);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
             // split cached V into n_head heads
@@ -3322,12 +3270,10 @@ static struct ggml_cgraph * llm_build_llama(
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
-            offload_func_v(V);
             ggml_set_name(V, "V");
 
 #if 1
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-            offload_func_v(KQV);
             ggml_set_name(KQV, "KQV");
 #else
             // make V contiguous in memory to speed up the matmul, however we waste time on the copy
@@ -3339,24 +3285,20 @@ static struct ggml_cgraph * llm_build_llama(
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-            offload_func_v(KQV_merged);
             ggml_set_name(KQV_merged, "KQV_merged");
 
             // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
             cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
-            offload_func_v(cur);
             ggml_set_name(cur, "KQV_merged_contiguous");
 
             // projection (no bias)
             cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
-            offload_func(cur);
             ggml_set_name(cur, "result_wo");
         }
 
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
-        offload_func(inpFF);
         ggml_set_name(inpFF, "inpFF");
 
         // feed-forward network
@@ -3364,45 +3306,37 @@ static struct ggml_cgraph * llm_build_llama(
             // norm
             {
                 cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
-                offload_func(cur);
                 ggml_set_name(cur, "rms_norm_1");
 
                 // cur = cur*ffn_norm(broadcasted)
                 cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                offload_func(cur);
                 ggml_set_name(cur, "ffn_norm");
             }
 
             struct ggml_tensor * tmp = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
-            offload_func(tmp);
             ggml_set_name(tmp, "result_w3");
 
             cur = ggml_mul_mat(ctx0, model.layers[il].w1, cur);
-            offload_func(cur);
             ggml_set_name(cur, "result_w1");
 
             // SILU activation
             cur = ggml_silu(ctx0, cur);
-            offload_func(cur);
             ggml_set_name(cur, "silu");
 
             cur = ggml_mul(ctx0, cur, tmp);
-            offload_func(cur);
             ggml_set_name(cur, "silu_x_result_w3");
 
             cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
-            offload_func(cur);
             ggml_set_name(cur, "result_w2");
         }
 
         cur = ggml_add(ctx0, cur, inpFF);
-        offload_func(cur);
         ggml_set_name(cur, "inpFF_+_result_w2");
 
         // input for next layer
@@ -3414,12 +3348,10 @@ static struct ggml_cgraph * llm_build_llama(
     // norm
     {
         cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
-        offload_func_nr(cur);
         ggml_set_name(cur, "rms_norm_2");
 
         // cur = cur*norm(broadcasted)
         cur = ggml_mul(ctx0, cur, model.output_norm);
-        // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
         ggml_set_name(cur, "result_norm");
     }
 
@@ -3884,7 +3816,6 @@ static struct ggml_cgraph * llm_build_refact(
 
     for (int il = 0; il < n_layer; ++il) {
         ggml_format_name(inpL, "layer_inp_%d", il);
-
         offload_func_t offload_func = llama_nop;
 
 #ifdef GGML_USE_CUBLAS
@@ -5641,7 +5572,7 @@ static struct ggml_cgraph * llama_build_graph(
             GGML_ASSERT(false);
     }
 
-    // set input data to the graph
+    // allocate memory and set the values for the input tensors of the graph
 
     // inp_tokens
     if (batch.token) {
@@ -5655,7 +5586,10 @@ static struct ggml_cgraph * llama_build_graph(
 
             memcpy(cur->data, batch.token, n_tokens*ggml_element_size(cur));
         }
-    } else { // inp_embd
+    }
+
+    // inp_embd
+    if (batch.embd) {
         struct ggml_tensor * cur = ggml_graph_get_tensor(result, "inp_embd");
 
         GGML_ASSERT(cur != nullptr);
@@ -5775,6 +5709,124 @@ static struct ggml_cgraph * llama_build_graph(
 
         }
     } while (0);
 
+    // offload layers
+
+    {
+        const int n_layer = model.hparams.n_layer;
+
+        const int n_gpu_layers = model.n_gpu_layers;
+        const int i_gpu_start  = n_layer - n_gpu_layers;
+
+        GGML_UNUSED(i_gpu_start);
+
+        // offload functions set the tensor output backend to GPU
+        // tensors are GPU-accelerated if any input or the output has been offloaded
+        offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+        offload_func_t offload_func_kq = llama_nop;
+        offload_func_t offload_func_v  = llama_nop;
+        offload_func_t offload_func    = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+        if (n_gpu_layers > n_layer) {
+            offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+        }
+        if (n_gpu_layers > n_layer + 1) {
+            offload_func_v = ggml_cuda_assign_buffers_no_alloc;
+        }
+        if (n_gpu_layers > n_layer + 2) {
+            offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+        }
+
+        offload_func = ggml_cuda_assign_buffers_no_alloc;
+#endif // GGML_USE_CUBLAS
+
+        static const std::unordered_map<std::string, offload_func_t> k_offload_func = {
+            { "KQ_mask",               offload_func_kq },
+            { "KQ_pos",                offload_func_kq },
+            { "K_shift",               offload_func_kq },
+            { "K_shifted",             offload_func_kq },
+
+            { "rms_norm_0",            offload_func    },
+            { "attention_norm_0",      offload_func    },
+
+            { "tmpk",                  offload_func_kq },
+            { "tmpq",                  offload_func_kq },
+            { "tmpv",                  offload_func_v  },
+            { "Kcur",                  offload_func_kq },
+            { "Qcur",                  offload_func_kq },
+            { "Vcur",                  offload_func_v  },
+
+            { "k",                     offload_func_kq },
+            { "v",                     offload_func_v  },
+
+            { "Q",                     offload_func_kq },
+            { "K",                     offload_func_kq },
+            { "KQ",                    offload_func_kq },
+            { "KQ_scaled",             offload_func_kq },
+            { "KQ_scaled_alibi",       offload_func_kq },
+            { "KQ_masked",             offload_func_kq },
+            { "KQ_soft_max",           offload_func_v  },
+            { "V",                     offload_func_v  },
+            { "KQV",                   offload_func_v  },
+            { "KQV_merged",            offload_func_v  },
+            { "KQV_merged_contiguous", offload_func_v  },
+
+            { "result_wo",             offload_func    },
+
+            { "inpFF",                 offload_func    },
+
+            { "rms_norm_1",            offload_func    },
+            { "ffn_norm",              offload_func    },
+
+            { "result_w3",             offload_func    },
+            { "result_w2",             offload_func    },
+            { "result_w1",             offload_func    },
+            { "silu",                  offload_func    },
+            { "silu_x_result_w3",      offload_func    },
+            { "inpFF_+_result_w2",     offload_func    },
+
+            { "rms_norm_2",            offload_func_nr },
+          //{ "result_norm",           offload_func_nr }, // TODO CPU + GPU mirrored backend
+          //{ "result_output",         offload_func    },
+        };
+
+        static const std::unordered_map<offload_func_t, std::string> k_offload_func_name = {
+            { llama_nop, "CPU" },
+#ifdef GGML_USE_CUBLAS
+            { ggml_cuda_assign_buffers_no_alloc, "GPU (CUDA)" },
+#endif
+        };
+
+        std::unordered_map<std::string, int> ofn;
+
+        for (int i = 0; i < result->n_nodes; ++i) {
+            struct ggml_tensor * cur = result->nodes[i];
+
+            const std::string name = cur->name;
+
+            if (k_offload_func.find(name) == k_offload_func.end()) {
+                if (worst_case && cur->view_src == nullptr) {
+                    LLAMA_LOG_WARN("%s: %32s: not offloaded (ref: %s)\n", __func__,
+                            name.c_str(), "https://github.com/ggerganov/llama.cpp/pull/3837");
+                }
+                continue;
+            }
+
+            offload_func_t f = k_offload_func.at(name);
+            if (f == offload_func) {
+                if (ofn[name]++ < i_gpu_start) {
+                    f = llama_nop;
+                }
+            }
+
+            f(cur);
+
+            if (worst_case && cur->view_src == nullptr) {
+                LLAMA_LOG_INFO("%s: %32s: %s\n", __func__, name.c_str(), k_offload_func_name.at(f).c_str());
+            }
+        }
+    }
+
     return result;
 }
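
Note: the heart of this change is the name-keyed offload pass added to llama_build_graph(). The build functions now only ggml_set_name() their tensors, and a single post-build loop looks each graph node up in a static name -> offload-function table, using a per-name occurrence counter so that repeated per-layer tensors below i_gpu_start stay on the CPU. The stand-alone sketch below models that pattern under stated assumptions: Tensor, Graph, nop() and gpu() are hypothetical stand-ins for ggml_tensor, ggml_cgraph, llama_nop and ggml_cuda_assign_buffers_no_alloc, and, unlike the patch (which applies the counter only to the repeating per-layer offload_func), it applies the counter to every table entry for brevity.

// sketch_offload.cpp - minimal model of the name-keyed offload pass.
// Tensor, Graph, nop() and gpu() are hypothetical stand-ins for the
// ggml/llama types and callbacks used by the patch.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

struct Tensor { std::string name; };
struct Graph  { std::vector<Tensor *> nodes; };

typedef void (*offload_fn)(Tensor *);

static void nop(Tensor *)   { }
static void gpu(Tensor * t) { std::printf("offloaded: %s\n", t->name.c_str()); }

// walk the built graph and offload nodes by name; skip the first
// i_gpu_start occurrences of each name so that only the last
// n_gpu_layers layers end up on the GPU
static void offload_by_name(Graph & g, int i_gpu_start) {
    static const std::unordered_map<std::string, offload_fn> k_offload = {
        { "Qcur", gpu },
        { "Kcur", gpu },
        { "Vcur", gpu },
    };

    std::unordered_map<std::string, int> ofn; // per-name occurrence counter

    for (Tensor * cur : g.nodes) {
        const auto it = k_offload.find(cur->name);
        if (it == k_offload.end()) {
            continue; // name not in the table -> tensor stays on the CPU
        }
        offload_fn f = it->second;
        if (ofn[cur->name]++ < i_gpu_start) {
            f = nop; // this occurrence belongs to a non-offloaded layer
        }
        f(cur);
    }
}

int main() {
    // two "layers"; with i_gpu_start == 1 only the second one is offloaded
    Tensor q0 = { "Qcur" };
    Tensor q1 = { "Qcur" };
    Graph g = { { &q0, &q1 } };
    offload_by_name(g, 1);
}

Built with e.g. g++ -std=c++11 sketch_offload.cpp, this prints a single "offloaded: Qcur" line for the second occurrence, mirroring how the pass leaves the bottom layers on the CPU and moves only the top n_gpu_layers layers to the GPU.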