From 3b4bab6a38502d9e68587c2c19f26472480ec4dd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Sep 2023 19:42:39 +0300 Subject: [PATCH] llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask) --- ggml-metal.m | 48 +++++++++++++++++++++++++++++++++------ ggml-metal.metal | 59 +++++++++++++++++++++++++++++++++++++++++++----- ggml.c | 2 -- llama.cpp | 50 +++++++++++++++++++++++++++------------- 4 files changed, 128 insertions(+), 31 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 1139ee311..d793083d9 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -736,25 +736,59 @@ void ggml_metal_graph_compute( GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src1)); - // utilize float4 - GGML_ASSERT(ne00 % 4 == 0); - const int64_t nb = ne00/4; + bool bcast_row = false; - if (ggml_nelements(src1) == ne10) { + int64_t nb = ne00; + + if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { // src1 is a row GGML_ASSERT(ne11 == 1); + + nb = ne00 / 4; [encoder setComputePipelineState:ctx->pipeline_add_row]; + + bcast_row = true; } else { [encoder setComputePipelineState:ctx->pipeline_add]; } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; - const int64_t n = ggml_nelements(dst)/4; + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } break; case GGML_OP_MUL: { diff --git a/ggml-metal.metal b/ggml-metal.metal index 7f1c3d9ea..c913fe1d9 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -24,12 +24,59 @@ typedef struct { int8_t qs[QK8_0]; // quants } block_q8_0; +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient kernel void kernel_add( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig]; + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & nb00, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant int64_t & nb0, + constant int64_t & nb1, + constant int64_t & nb2, + constant int64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0]; + + src0_ptr += ntg.x*nb00; + src1_ptr += ntg.x*nb10; + dst_ptr += ntg.x*nb0; + } } // assumption: src1 is a row @@ -38,7 +85,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb, + constant int64_t & nb [[buffer(27)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } diff --git a/ggml.c b/ggml.c index fd3a51ba7..37124f1e5 100644 --- a/ggml.c +++ b/ggml.c @@ -8797,8 +8797,6 @@ static void ggml_compute_forward_add_f32( #else ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif - // } - // } } } else { // src1 is not contiguous diff --git a/llama.cpp b/llama.cpp index 0cab18093..f129c59f3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2404,6 +2404,7 @@ static struct ggml_cgraph * llm_build_llama( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { @@ -2411,6 +2412,22 @@ static struct ggml_cgraph * llm_build_llama( } ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + // KQ_mask + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < N; ++j) { + for (int i = n_past + j + 1; i < n_past + N; ++i) { + data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY; + } + } + } + } + for (int il = 0; il < n_layer; ++il) { ggml_format_name(inpL, "layer_inp_%d", il); @@ -2447,11 +2464,11 @@ static struct ggml_cgraph * llm_build_llama( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); @@ -2502,17 +2519,18 @@ static struct ggml_cgraph * llm_build_llama( // KQ_scaled = KQ / sqrt(n_embd_head) // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); @@ -2783,8 +2801,8 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_tensor * Qcur; switch (model.type) { case MODEL_7B: - Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); - Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); break; case MODEL_13B: Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N); @@ -2847,7 +2865,7 @@ static struct ggml_cgraph * llm_build_baichaun( // KQ_scaled = KQ / sqrt(n_embd_head) // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); @@ -2856,7 +2874,7 @@ static struct ggml_cgraph * llm_build_baichaun( switch (model.type) { case MODEL_7B: - KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); break; case MODEL_13B: KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8); @@ -2867,13 +2885,13 @@ static struct ggml_cgraph * llm_build_baichaun( GGML_ASSERT(false); } // KQ_masked = mask_past(KQ_scaled) - // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); // offload_func_kq(KQ_masked); // ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); @@ -3179,9 +3197,9 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_v(tmpv); // using mode = 2 for neox mode - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale); offload_func_kq(Qcur); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale); offload_func_kq(Kcur); { @@ -3220,15 +3238,15 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max");