From 3ebb00935f3f0522b75df49c2769ab1774b91380 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Tue, 15 Aug 2023 06:14:14 +0800 Subject: [PATCH 01/10] server : add missing /json-schema-to-grammar.mjs (#2616) fixes #2611 --- examples/server/server.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 222dbcb43..99660455a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -15,6 +15,7 @@ #include "index.html.hpp" #include "index.js.hpp" #include "completion.js.hpp" +#include "json-schema-to-grammar.mjs.hpp" #ifndef SERVER_VERBOSE #define SERVER_VERBOSE 1 @@ -1218,6 +1219,12 @@ int main(int argc, char **argv) res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript"); return false; }); + // this is only called if no index.html is found in the public --path + svr.Get("/json-schema-to-grammar.mjs", [](const Request &, Response &res) + { + res.set_content(reinterpret_cast(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript"); + return false; }); + svr.Post("/completion", [&llama](const Request &req, Response &res) { auto lock = llama.lock(); From b5ffb2849d23afe73647f68eec7b68187af09be6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 15 Aug 2023 10:04:58 +0300 Subject: [PATCH 02/10] scripts : add helper script to get wikitext --- scripts/get-wikitext-2.sh | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 scripts/get-wikitext-2.sh diff --git a/scripts/get-wikitext-2.sh b/scripts/get-wikitext-2.sh new file mode 100644 index 000000000..98aec3e3e --- /dev/null +++ b/scripts/get-wikitext-2.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip From bf83bff6742c0f1795b4c18695a13a34ac7adf62 Mon Sep 17 00:00:00 2001 From: Shouzheng Liu Date: Wed, 16 Aug 2023 16:07:04 -0400 Subject: [PATCH 03/10] metal : matrix-matrix multiplication kernel (#2615) * metal: matrix-matrix multiplication kernel This commit removes MPS and uses custom matrix-matrix multiplication kernels for all quantization types. This commit also adds grouped-query attention to support llama2 70B. * metal: fix performance degradation from gqa Integers are slow on the GPU, and 64-bit divides are extremely slow. In the context of GQA, we introduce a 64-bit divide that cannot be optimized out by the compiler, which results in a decrease of ~8% in inference performance. This commit fixes that issue by calculating a part of the offset with a 32-bit divide. Naturally, this limits the size of a single matrix to ~4GB. However, this limitation should suffice for the near future. * metal: fix bugs for GQA and perplexity test. I mixed up ne02 and nb02 in previous commit. --- CMakeLists.txt | 2 - Makefile | 2 +- flake.nix | 2 - ggml-metal.m | 171 +++------ ggml-metal.metal | 969 +++++++++++++++++++++++------------------------ llama.cpp | 18 +- 6 files changed, 528 insertions(+), 636 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index dff4942cd..01b40c2e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -296,7 +296,6 @@ if (LLAMA_METAL) find_library(FOUNDATION_LIBRARY Foundation REQUIRED) find_library(METAL_FRAMEWORK Metal REQUIRED) find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) - find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED) set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) @@ -313,7 +312,6 @@ if (LLAMA_METAL) ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK} - ${METALPERFORMANCE_FRAMEWORK} ) endif() diff --git a/Makefile b/Makefile index 070ae1242..5b801d16f 100644 --- a/Makefile +++ b/Makefile @@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST ifdef LLAMA_METAL CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG CXXFLAGS += -DGGML_USE_METAL - LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders + LDFLAGS += -framework Foundation -framework Metal -framework MetalKit OBJS += ggml-metal.o endif # LLAMA_METAL diff --git a/flake.nix b/flake.nix index 4178e97ff..616b90252 100644 --- a/flake.nix +++ b/flake.nix @@ -14,8 +14,6 @@ with pkgs.darwin.apple_sdk_11_0.frameworks; [ Accelerate MetalKit - MetalPerformanceShaders - MetalPerformanceShadersGraph ] else if isAarch32 && isDarwin then with pkgs.darwin.apple_sdk.frameworks; [ diff --git a/ggml-metal.m b/ggml-metal.m index fbac21e3a..e13cb4b3c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -5,7 +5,6 @@ #import #import -#import #undef MIN #undef MAX @@ -79,6 +78,14 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); @@ -110,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { ctx->n_buffers = 0; ctx->concur_list_len = 0; - // determine if we can use MPS - if (MPSSupportsMTLDevice(ctx->device)) { - fprintf(stderr, "%s: using MPS\n", __func__); - } else { - fprintf(stderr, "%s: not using MPS\n", __func__); - GGML_ASSERT(false && "MPS not supported"); - } #if 0 // compile from source string and show compile log @@ -196,6 +196,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); @@ -506,7 +514,7 @@ void ggml_metal_graph_compute( id command_buffer = command_buffers[cb_idx]; - id encoder = nil; + id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; const int node_start = (cb_idx + 0) * n_nodes_per_cb; const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb; @@ -515,10 +523,6 @@ void ggml_metal_graph_compute( const int i = has_concur ? ctx->concur_list[ind] : ind; if (i == -1) { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - continue; - } [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; continue; } @@ -592,10 +596,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_ADD: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - if (ggml_nelements(src1) == ne10) { // src1 is a row [encoder setComputePipelineState:ctx->pipeline_add_row]; @@ -613,10 +613,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_MUL: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - if (ggml_nelements(src1) == ne10) { // src1 is a row [encoder setComputePipelineState:ctx->pipeline_mul_row]; @@ -634,10 +630,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_SCALE: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - const float scale = *(const float *) src1->data; [encoder setComputePipelineState:ctx->pipeline_scale]; @@ -653,10 +645,6 @@ void ggml_metal_graph_compute( switch (ggml_get_unary_op(gf->nodes[i])) { case GGML_UNARY_OP_SILU: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - [encoder setComputePipelineState:ctx->pipeline_silu]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -667,10 +655,6 @@ void ggml_metal_graph_compute( } break; case GGML_UNARY_OP_RELU: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - [encoder setComputePipelineState:ctx->pipeline_relu]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -681,10 +665,6 @@ void ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - [encoder setComputePipelineState:ctx->pipeline_gelu]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -701,10 +681,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - const int nth = 32; [encoder setComputePipelineState:ctx->pipeline_soft_max]; @@ -719,10 +695,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_DIAG_MASK_INF: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - const int n_past = ((int32_t *)(dst->op_params))[0]; [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; @@ -740,53 +712,43 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne00 == ne10); // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere + uint gqa = ne12/ne02; GGML_ASSERT(ne03 == ne13); + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && - (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) { - - if (encoder != nil) { - [encoder endEncoding]; - encoder = nil; + src1t == GGML_TYPE_F32 && + [ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne00%32 == 0 && + ne11 > 1) { + switch (src0->type) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; + default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } - - MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16; - MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16; - - // for F32 x F32 we use MPS - MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor - matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt]; - - MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor - matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt]; - - MPSMatrixDescriptor * desc = [MPSMatrixDescriptor - matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32]; - - MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] - initWithDevice:ctx->device transposeLeft:false transposeRight:true - resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0]; - - // we need to do ne12 multiplications - // TODO: is there a way to do this in parallel - currently very slow .. - // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS - for (int64_t i02 = 0; i02 < ne12; ++i02) { - size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now - size_t offs_src1_cur = offs_src1 + i02*nb12; - size_t offs_dst_cur = offs_dst + i02*nb2; - - MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0]; - MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1]; - MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ]; - - [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst]; - } - } else { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - + else { int nth0 = 32; int nth1 = 1; @@ -885,23 +847,24 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif } else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -910,10 +873,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_GET_ROWS: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - switch (src0->type) { case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; @@ -939,10 +898,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_RMS_NORM: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -962,10 +917,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_NORM: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - const float eps = 1e-5f; const int nth = 256; @@ -984,10 +935,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_ALIBI: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - GGML_ASSERT((src0t == GGML_TYPE_F32)); const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); @@ -1027,10 +974,6 @@ void ggml_metal_graph_compute( } break; case GGML_OP_ROPE: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -1071,10 +1014,6 @@ void ggml_metal_graph_compute( case GGML_OP_CPY: case GGML_OP_CONT: { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - } - const int nth = 32; switch (src0t) { diff --git a/ggml-metal.metal b/ggml-metal.metal index 8d26b5ec2..3f3125236 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -18,47 +18,6 @@ typedef struct { uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) { - const int qk = QK4_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const half d = x[i].d; - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F) - 8; - const int x1 = (x[i].qs[j] >> 4) - 8; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } - } -} - -static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) { - const int qk = QK4_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const half d = x[i].d; - const half m = x[i].m; - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F); - const int x1 = (x[i].qs[j] >> 4); - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } - } -} - kernel void kernel_add( device const float * src0, device const float * src1, @@ -219,54 +178,6 @@ kernel void kernel_diag_mask_inf( } } -kernel void kernel_get_rows_f16( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - for (int j = 0; j < ne00; j++) { - dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j]; - } -} - -kernel void kernel_get_rows_q4_0( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q4_0( - (device const block_q4_0 *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q4_1( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q4_1( - (device const block_q4_1 *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - kernel void kernel_norm( device const void * src0, device float * dst, @@ -432,14 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre // N_DST, so this is another explicit assumption of the implementation. template void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, - int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, - uint2 tgpig, uint tiisg, uint sgitg) { + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, + uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; + const int im = tgpig.z; const int first_row = (r0 * nsg + sgitg) * nr; - device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb; - device const float * y = (device const float *) src1 + r1*ne10; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + device const block_q_type * x = (device const block_q_type *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; float yl[16]; // src1 vector cache float sumf[nr]={0.f}; @@ -470,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + first_row + row] = tot; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; } } } @@ -480,13 +393,17 @@ kernel void kernel_mul_mat_q4_0_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_q4_1_f32( @@ -494,13 +411,17 @@ kernel void kernel_mul_mat_q4_1_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( @@ -869,354 +790,6 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { return r; } -//========================================== dequantization ============================= - -static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d = x[i].d; - const float min = x[i].dmin; - - device const uint8_t * q = x[i].qs; - -#if QK_K == 256 - int is = 0; - float dl, ml; - for (int n = 0; n < QK_K; n += 128) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - - uint8_t sc = x[i].scales[is++]; - dl = d * (sc & 0xF); ml = min * (sc >> 4); - for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; - - sc = x[i].scales[is++]; - dl = d * (sc & 0xF); ml = min * (sc >> 4); - for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; - - shift += 2; - } - q += 32; - } -#else - float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); - float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); - float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); - float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); - for (int l = 0; l < 16; ++l) { - y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1; - y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2; - y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3; - y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4; - } - y += QK_K; -#endif - - } -} - -static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - -#if QK_K == 256 - - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - uint16_t aux[8]; - thread const int8_t * scales = (thread const int8_t*)aux; - - for (int i = 0; i < nb; i++) { - - const float d_all = (float)(x[i].d); - - device const uint8_t * q = x[i].qs; - device const uint8_t * h = x[i].hmask; - uint8_t m = 1; - - device const uint16_t * a = (device const uint16_t *)x[i].scales; - aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4); - aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4); - aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4); - aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4); - aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4); - aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4); - aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4); - aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4); - - int is = 0; - float dl; - for (int n = 0; n < QK_K; n += 128) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - - dl = d_all * (scales[is++] - 32); - for (int l = 0; l < 16; ++l) { - *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4)); - } - - dl = d_all * (scales[is++] - 32); - for (int l = 0; l < 16; ++l) { - *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4)); - } - - shift += 2; - m <<= 1; - } - q += 32; - } - } -#else - for (int i = 0; i < nb; i++) { - - const float d_all = (float)(x[i].d); - - device const uint8_t * q = x[i].qs; - device const uint8_t * hm = x[i].hmask; - - const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); - const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); - const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); - const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); - - for (int l = 0; l < 8; ++l) { - uint8_t h = hm[l]; - y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); - y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); - y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); - y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); - y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); - y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); - y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); - y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4)); - } - y += QK_K; - } -#endif - -} - -static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - device const uint8_t * q = x[i].qs; - -#if QK_K == 256 - const float d = x[i].d; - const float min = x[i].dmin; - - device const uint8_t * scales = x[i].scales; - - int is = 0; - for (int j = 0; j < QK_K; j += 64) { - const uchar4 sc = get_scale_min_k4(is, scales); - const float d1 = d * sc[0]; const float m1 = min * sc[1]; - const float d2 = d * sc[2]; const float m2 = min * sc[3]; - for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; - for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; - q += 32; is += 2; - } -#else - device const uint8_t * s = x[i].scales; - device const half2 * dh = (device const half2 *)x[i].d; - const float2 d = (float2)dh[0]; - const float d1 = d[0] * (s[0] & 0xF); - const float d2 = d[0] * (s[1] & 0xF); - const float m1 = d[1] * (s[0] >> 4); - const float m2 = d[1] * (s[1] >> 4); - for (int l = 0; l < 32; ++l) { - y[l+ 0] = d1 * (q[l] & 0xF) - m1; - y[l+32] = d2 * (q[l] >> 4) - m2; - } - y += QK_K; -#endif - - } -} - -static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - -#if QK_K == 256 - for (int i = 0; i < nb; i++) { - - const float d = (float)(x[i].d); - const float min = (float)(x[i].dmin); - - device const uint8_t * ql = x[i].qs; - device const uint8_t * qh = x[i].qh; - - int is = 0; - uint8_t u1 = 1, u2 = 2; - for (int j = 0; j < QK_K; j += 64) { - const uchar4 sc = get_scale_min_k4(is, x[i].scales); - const float d1 = d * sc[0]; const float m1 = min * sc[1]; - const float d2 = d * sc[2]; const float m2 = min * sc[3]; - for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; - for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; - ql += 32; is += 2; - u1 <<= 2; u2 <<= 2; - } - } -#else - for (int i = 0; i < nb; i++) { - - const float d = (float)x[i].d; - - device const uint8_t * ql = x[i].qs; - device const uint8_t * qh = x[i].qh; - device const int8_t * sc = x[i].scales; - - for (int l = 0; l < 8; ++l) { - y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); - y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); - y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); - y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); - y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); - y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); - y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); - y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16)); - } - y += QK_K; - } -#endif - -} - -static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - device const uint8_t * ql = x[i].ql; - device const uint8_t * qh = x[i].qh; - device const int8_t * sc = x[i].scales; - - const float d = x[i].d; - -#if QK_K == 256 - for (int n = 0; n < QK_K; n += 128) { - for (int l = 0; l < 32; ++l) { - int is = l/16; - const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - y[l + 0] = d * sc[is + 0] * q1; - y[l + 32] = d * sc[is + 2] * q2; - y[l + 64] = d * sc[is + 4] * q3; - y[l + 96] = d * sc[is + 6] * q4; - } - y += 128; - ql += 64; - qh += 32; - sc += 8; - } -#else - for (int l = 0; l < 16; ++l) { - const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - y[l+ 0] = d * sc[0] * q1; - y[l+16] = d * sc[1] * q2; - y[l+32] = d * sc[2] * q3; - y[l+48] = d * sc[3] * q4; - } - y += 64; -#endif - } -} - -kernel void kernel_get_rows_q2_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q2_K( - (device const block_q2_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q3_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q3_K( - (device const block_q3_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q4_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q4_K( - (device const block_q4_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q5_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q5_K( - (device const block_q5_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q6_K( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q6_K( - (device const block_q6_K *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - //====================================== dot products ========================= kernel void kernel_mul_mat_q2_K_f32( @@ -1224,21 +797,27 @@ kernel void kernel_mul_mat_q2_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; + const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row; - device const float * y = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -1351,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = all_sum; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; } } } @@ -1362,10 +941,14 @@ kernel void kernel_mul_mat_q3_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - constant int64_t & ne1, - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1373,11 +956,12 @@ kernel void kernel_mul_mat_q3_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int64_t r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[16]; @@ -1465,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32( const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; } } } @@ -1475,10 +1059,14 @@ kernel void kernel_mul_mat_q3_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - constant int64_t & ne1, - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1486,11 +1074,12 @@ kernel void kernel_mul_mat_q3_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int64_t r2 = tgpig.z; const int row = 2 * r0 + sgitg; - - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; const int ix = tiisg/4; const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int im = il/8; // 0, 0, 1, 1 @@ -1529,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + row] = tot; } } @@ -1541,10 +1130,14 @@ kernel void kernel_mul_mat_q4_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1560,10 +1153,12 @@ kernel void kernel_mul_mat_q4_K_f32( const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; + const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; - device const float * y = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[16]; float yh[16]; float sumf[N_DST]={0.f}, all_sum; @@ -1630,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = all_sum; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; } } } @@ -1640,10 +1235,14 @@ kernel void kernel_mul_mat_q4_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1653,10 +1252,12 @@ kernel void kernel_mul_mat_q4_K_f32( const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; + const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; - device const float * y = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[8]; float yh[8]; float sumf[N_DST]={0.f}, all_sum; @@ -1712,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = all_sum; + dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum; } } } @@ -1723,9 +1324,14 @@ kernel void kernel_mul_mat_q5_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1733,11 +1339,12 @@ kernel void kernel_mul_mat_q5_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float sumf[2]={0.f}; @@ -1871,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32( for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; } } @@ -1882,9 +1489,14 @@ kernel void kernel_mul_mat_q6_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - uint2 tgpig[[threadgroup_position_in_grid]], + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1897,11 +1509,12 @@ kernel void kernel_mul_mat_q6_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int r2 = tgpig.z; const int row = 2 * r0 + sgitg; - - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float sumf = 0; @@ -1967,6 +1580,366 @@ kernel void kernel_mul_mat_q6_K_f32( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + row] = tot; } } + +//============================= templates and their specializations ============================= + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + half4x4 temp = *(((device half4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const half d = il ? (xb->d / 16.h) : xb->d; + const half m = il ? (-8.h * 16.h) : -8.h; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = il ? 0xF000 : 0x0F00; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d; + reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; + } +} + +template +void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const half d = il ? (xb->d / 16.h) : xb->d; + const half m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = il ? 0xF000 : 0x0F00; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m; + reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m; + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const half d = xb->d; + const half min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + half dl, ml; + uint8_t sc = xb->scales[il]; + +#if QK_K == 256 + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; +#endif + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const float d_all = (float)(xb->d); + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + +#if QK_K == 256 + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \ + (scale_2&kmask2) | ((scale_1&kmask1) << 4); + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + + il = (il/2)%4; + float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef)); + } +#else + float kcoef = il&1 ? 1.f/16.f : 1.f; + uint16_t kmask = il&1 ? 0xF0 : 0x0F; + float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); + float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + uint8_t m = 1<<(il*2); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); + } +#endif +} + +template +void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + +#if QK_K == 256 + const float d = (float)(xb->d); + const float min = (float)(xb->dmin); + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il%4; + const uchar4 sc = get_scale_min_k4(is, xb->scales); + const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; + const float ml = il<2 ? min * sc[1] : min * sc[3]; +#else + q = q + 16 * (il&1); + device const uint8_t * s = xb->scales; + device const half2 * dh = (device const half2 *)xb->d; + const float2 d = (float2)dh[0]; + const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; + const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4); +#endif + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + +#if QK_K == 256 + const float d = (float)(xb->d); + const float min = (float)(xb->dmin); + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il%4; + const uchar4 sc = get_scale_min_k4(is, xb->scales); + const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; + const float ml = il<2 ? min * sc[1] : min * sc[3]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; + } +#else + q = q + 16 * (il&1); + device const int8_t * s = xb->scales; + const float dl = xb->d * s[il]; + uint8_t m = 1<<(il*2); + const float coef = il<2 ? 1.f : 1.f/16.f; + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); + } +#endif +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const float d_all = (float)(xb->d); + device const uint8_t * ql = (device const uint8_t *)xb->ql; + device const uint8_t * qh = (device const uint8_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + +#if QK_K == 256 + ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + qh = qh + 32*(il/8) + 16*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2)%4; +#else + ql = ql + 16 * (il&1); + float sc = scales[il]; +#endif + for (int i = 0; i < 16; ++i) { + uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const float coef = il>1 ? 1.f/16.f : 1.f; + float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \ + ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef; + reg[i/4][i%4] = d_all * sc * q * coef; + } +} + +template +kernel void kernel_get_rows( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tptg[[threads_per_threadgroup]]) { + const int i = tgpig; + const int r = ((device int32_t *) src1)[i]; + + for (int ind = tiitg; ind < ne00/16; ind += tptg) { + float4x4 temp; + dequantize_func( + ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp; + } +} + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template +kernel void kernel_mul_mm(device const uchar * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = ((threadgroup half *)shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ + + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1; + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + //load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + #pragma unroll(16) + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \ + = *((device float2x4 *)y); + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + //load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(4) + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } + + if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { + device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup float *temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + if (sgitg==0) { + for (int i = 0; i < n_rows; i++) { + for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) { + *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } +} + +#if QK_K == 256 +#define QK_NL 16 +#else +#define QK_NL 4 +#endif + +typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ + constant uint64_t &, constant uint64_t &, uint, uint, uint); + +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; + +typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ + constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ + constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); + +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/llama.cpp b/llama.cpp index c8ab313d9..a161f1566 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1845,7 +1845,7 @@ static bool llama_eval_internal( #endif #ifdef GGML_USE_METAL - if (lctx.ctx_metal && N == 1) { + if (lctx.ctx_metal) { // TODO: disabled until #2413 is resolved //if (!ggml_metal_if_optimized(lctx.ctx_metal)) { // ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf); @@ -1857,22 +1857,6 @@ static bool llama_eval_internal( ggml_metal_get_tensor(lctx.ctx_metal, embeddings); } } else { - // IMPORTANT: - // Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla - // ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX - // coprocessor. - // - // When we implement Matrix x Matrix Metal multiplication, we can avoid this branch. - // But for now, we have focused only on Matrix x Vector Metal multiplication. - // - // TODO: avoid these syncs via shared memory (ref #1696) - // - if (lctx.ctx_metal) { - // We need to sync the GPU KV cache with the CPU KV cache - ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k); - ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v); - } - ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); } #else From fc8ef549e50087762a0b4f901cd74b2defcc6ae3 Mon Sep 17 00:00:00 2001 From: Shouzheng Liu Date: Wed, 16 Aug 2023 16:08:28 -0400 Subject: [PATCH 04/10] metal : enable ggml-alloc (#2627) * metal: enable ggml-alloc Make ggml-alloc work with concurrently dispatch. * style-fix Co-authored-by: slaren --------- Co-authored-by: slaren Co-authored-by: Georgi Gerganov --- ggml-alloc.c | 25 ++++++++++++++++++++++++- ggml-alloc.h | 4 ++++ ggml-metal.h | 9 ++++++--- ggml-metal.m | 15 ++++++++------- llama.cpp | 34 +++++++++++++++++++--------------- 5 files changed, 61 insertions(+), 26 deletions(-) diff --git a/ggml-alloc.c b/ggml-alloc.c index 4121f3dba..8de28cf9d 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -67,6 +67,8 @@ struct ggml_allocr { struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE]; size_t max_size; bool measure; + int parse_seq[GGML_MAX_NODES]; + bool has_parse_seq; #ifdef GGML_ALLOCATOR_DEBUG struct ggml_tensor * allocated_tensors[1024]; @@ -229,6 +231,17 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t alloc->n_free_blocks++; } +void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) { + int pos = 0; + for (int i = 0; i < n; i++) { + if (list[i] != -1) { + alloc->parse_seq[pos] = list[i]; + pos++; + } + } + alloc->has_parse_seq = true; +} + void ggml_allocr_reset(struct ggml_allocr * alloc) { alloc->n_free_blocks = 1; size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment); @@ -248,6 +261,8 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) /*.hash_table = */ {{0}}, /*.max_size = */ 0, /*.measure = */ false, + /*.parse_seq = */ {0}, + /*.has_parse_seq = */ false, #ifdef GGML_ALLOCATOR_DEBUG /*.allocated_tensors = */ = {0}, #endif @@ -275,6 +290,8 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { /*.hash_table = */ {{0}}, /*.max_size = */ 0, /*.measure = */ true, + /*.parse_seq = */ {0}, + /*.has_parse_seq = */ false, #ifdef GGML_ALLOCATOR_DEBUG /*.allocated_tensors = */ = {0}, #endif @@ -473,7 +490,13 @@ static size_t ggml_allocator_alloc_graph_tensors_n( allocate_node(alloc, input); } } - for (int i = 0; i < gf->n_nodes; i++) { + for (int ind = 0; ind < gf->n_nodes; ind++) { + int i; + if (alloc->has_parse_seq) { + i = alloc->parse_seq[ind]; + } else { + i = ind; + } struct ggml_tensor * node = gf->nodes[i]; // allocate parents (leafs) diff --git a/ggml-alloc.h b/ggml-alloc.h index a5ec8f87a..14a4350ac 100644 --- a/ggml-alloc.h +++ b/ggml-alloc.h @@ -10,6 +10,10 @@ extern "C" { GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment); GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment); +// tell the allocator to parse nodes following the order described in the list +// you should call this if your graph are optimized to execute out-of-order +GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n); + GGML_API void ggml_allocr_free(struct ggml_allocr * alloc); GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc); GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc); diff --git a/ggml-metal.h b/ggml-metal.h index 16f1a0caa..bf3f9a6a8 100644 --- a/ggml-metal.h +++ b/ggml-metal.h @@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * // try to find operations that can be run concurrently in the graph // you should run it again if the topology of your graph changes -void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); +void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem); -// if the graph has been optimized for concurrently dispatch -bool ggml_metal_if_optimized(struct ggml_metal_context * ctx); +// if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized +int ggml_metal_if_optimized(struct ggml_metal_context * ctx); + +// output the concur_list for ggml_alloc +int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx); // same as ggml_graph_compute but uses Metal // creates gf->n_threads command buffers in parallel diff --git a/ggml-metal.m b/ggml-metal.m index e13cb4b3c..32c6e4869 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -236,11 +236,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) { ctx->n_cb = n_cb; } -bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) { - if (ctx->concur_list_len) { - return true; - } - return false; +int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { + return ctx->concur_list_len; +} + +int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) { + return ctx->concur_list; } // finds the Metal buffer that contains the tensor data on the GPU device @@ -383,7 +384,7 @@ void ggml_metal_get_tensor( void ggml_metal_graph_find_concurrency( struct ggml_metal_context * ctx, - struct ggml_cgraph * gf) { + struct ggml_cgraph * gf, bool check_mem) { int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time int nodes_unused[GGML_MAX_CONCUR]; @@ -430,7 +431,7 @@ void ggml_metal_graph_find_concurrency( } } } - if (exe_flag) { + if (exe_flag && check_mem) { // check if nodes[i]'s data will be overwritten by a node before nodes[i]. // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3] int64_t data_start = (int64_t) gf->nodes[i]->data; diff --git a/llama.cpp b/llama.cpp index a161f1566..345243990 100644 --- a/llama.cpp +++ b/llama.cpp @@ -63,7 +63,7 @@ static void llama_log_callback_default(llama_log_level level, const char * text, #define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__) -#if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_METAL) +#if !defined(GGML_USE_CUBLAS) #include "ggml-alloc.h" #define LLAMA_USE_ALLOCATOR #else @@ -1846,10 +1846,6 @@ static bool llama_eval_internal( #ifdef GGML_USE_METAL if (lctx.ctx_metal) { - // TODO: disabled until #2413 is resolved - //if (!ggml_metal_if_optimized(lctx.ctx_metal)) { - // ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf); - //} ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); ggml_metal_graph_compute(lctx.ctx_metal, gf); ggml_metal_get_tensor (lctx.ctx_metal, res); @@ -3287,7 +3283,18 @@ struct llama_context * llama_new_context_with_model( int n_past = hparams.n_ctx - n_tokens; llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past); - +#ifdef GGML_USE_METAL + if (params.n_gpu_layers > 0) { + ctx->ctx_metal = ggml_metal_init(1); + if (!ctx->ctx_metal) { + LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); + llama_free(ctx); + return NULL; + } + ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); + ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + } +#endif // measure memory requirements for the graph size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; @@ -3305,6 +3312,11 @@ struct llama_context * llama_new_context_with_model( ctx->buf_alloc.resize(alloc_size); ctx->alloc = ggml_allocr_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, tensor_alignment); +#ifdef GGML_USE_METAL + if (ctx->ctx_metal) { + ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + } +#endif } #else ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead()); @@ -3319,13 +3331,6 @@ struct llama_context * llama_new_context_with_model( #ifdef GGML_USE_METAL if (params.n_gpu_layers > 0) { // this allocates all Metal resources and memory buffers - ctx->ctx_metal = ggml_metal_init(1); - - if (!ctx->ctx_metal) { - LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); - llama_free(ctx); - return NULL; - } void * data_ptr = NULL; size_t data_size = 0; @@ -3354,8 +3359,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.addr, ctx->kv_self.buf.size, 0)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.addr, ctx->buf_alloc.size, 0)); #undef LLAMA_METAL_CHECK_BUF } #endif From ed53db86c3b0e0815331a96d7a379edb5e62472c Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Thu, 17 Aug 2023 04:09:03 +0800 Subject: [PATCH 05/10] metal : print error of load pipeline state (#2564) * metal : print error of load pipeline state * metal : return null if load pipeline failed --- ggml-metal.m | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 32c6e4869..d23fff1dd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -163,10 +163,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { // load kernels { + NSError * error = nil; #define GGML_METAL_ADD_KERNEL(name) \ ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ - ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \ - fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); + ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ + fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \ + if (error) { \ + fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + return NULL; \ + } GGML_METAL_ADD_KERNEL(add); GGML_METAL_ADD_KERNEL(add_row); From 0919a0f73d95cfb93a1646a1d1741a0615fe2c5e Mon Sep 17 00:00:00 2001 From: Kolen Cheung Date: Wed, 16 Aug 2023 21:09:49 +0100 Subject: [PATCH 06/10] cmake : install ggml-meta.metal if LLAMA_METAL (#2449) --- CMakeLists.txt | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 01b40c2e8..824d9f2cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -569,6 +569,16 @@ install( WORLD_READ WORLD_EXECUTE DESTINATION ${CMAKE_INSTALL_BINDIR}) +if (LLAMA_METAL) + install( + FILES ggml-metal.metal + PERMISSIONS + OWNER_READ + OWNER_WRITE + GROUP_READ + WORLD_READ + DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() # # programs, examples and tests From a872a2b28eaefc8d464eaa535c94deeb501666f9 Mon Sep 17 00:00:00 2001 From: Shouzheng Liu Date: Thu, 17 Aug 2023 03:35:53 -0400 Subject: [PATCH 07/10] ggml-alloc : fix discrepency between measure&eval (#2639) The GGML memory allocator consistently places a tensor within the optimal-fit memory block, which is the smallest block capable of accommodating the tensor's size. During the measurement phase, the final block is generously sized, ensuring it never qualifies as the optimal-fit block as long as there exists another block capable of accommodating the tensor. Nevertheless, in the evaluation phase, the last block is constrained in size and could potentially qualify as the optimal-fit block. Consequently, there exists the possibility of a tensor being allocated to a different region during evaluation, leading to more memory fragmentation in our scratch buffer. This recent commit guarantees uniform behavior of the allocator across both the measurement and evaluation phases, eliminating discrepancies between the two. --- ggml-alloc.c | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ggml-alloc.c b/ggml-alloc.c index 8de28cf9d..3ee98d03d 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -113,10 +113,10 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) size_t max_avail = 0; - // find the best fitting free block + // find the best fitting free block besides the last block int best_fit_block = -1; size_t best_fit_size = SIZE_MAX; - for (int i = 0; i < alloc->n_free_blocks; i++) { + for (int i = 0; i < alloc->n_free_blocks - 1; i++) { struct free_block * block = &alloc->free_blocks[i]; max_avail = MAX(max_avail, block->size); if (block->size >= size && block->size <= best_fit_size) { @@ -128,10 +128,17 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) AT_PRINTF("block %d\n", best_fit_block); if (best_fit_block == -1) { - fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n", - __func__, size, max_avail); - GGML_ASSERT(!"not enough space in the buffer"); + // the last block is our last resort + struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; + if (block->size >= size) { + best_fit_block = alloc->n_free_blocks - 1; + max_avail = MAX(max_avail, block->size); + } else { + fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n", + __func__, size, max_avail); + GGML_ASSERT(!"not enough space in the buffer"); return; + } } struct free_block * block = &alloc->free_blocks[best_fit_block]; void * addr = block->addr; From 7cf54e1f746941279d81d485796777c01f88049c Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 17 Aug 2023 03:41:01 -0400 Subject: [PATCH 08/10] tests : adds simple llama grammar tests (#2618) * adds simple llama grammar tests * fix lint and add Makefile * 0 terminate code_points * avoid dangling pointers in candidate cleanup * cleanup grammar at end of test --- Makefile | 5 +- tests/CMakeLists.txt | 1 + tests/test-llama-grammar.cpp | 403 +++++++++++++++++++++++++++++++++++ 3 files changed, 408 insertions(+), 1 deletion(-) create mode 100644 tests/test-llama-grammar.cpp diff --git a/Makefile b/Makefile index 5b801d16f..376a091dc 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test # Binaries only useful for tests -TEST_TARGETS = tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0 +TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0 default: $(BUILD_TARGETS) @@ -412,6 +412,9 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) +tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) + tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 689fb6f2a..276f39b3b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -12,5 +12,6 @@ llama_add_test(test-quantize-perf.cpp) llama_add_test(test-sampling.cpp) llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin) llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp) +llama_add_test(test-llama-grammar.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/common.cpp) llama_add_test(test-grad0.cpp) # SLOW # llama_add_test(test-opt.cpp) # SLOW diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp new file mode 100644 index 000000000..f98c6531f --- /dev/null +++ b/tests/test-llama-grammar.cpp @@ -0,0 +1,403 @@ +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include "llama.cpp" +#include "examples/common.cpp" +#include "examples/grammar-parser.cpp" +#include + +int main() +{ + grammar_parser::parse_state parsed_grammar; + + std::vector> expected = { + {"expr", 2}, + {"expr_6", 6}, + {"expr_7", 7}, + {"ident", 8}, + {"ident_10", 10}, + {"num", 9}, + {"num_11", 11}, + {"root", 0}, + {"root_1", 1}, + {"root_5", 5}, + {"term", 4}, + {"ws", 3}, + {"ws_12", 12}, + }; + + std::vector> expected_rules = { + {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}}, + { + {LLAMA_GRETYPE_RULE_REF, 2}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_RULE_REF, 4}, + {LLAMA_GRETYPE_CHAR, 10}, + {LLAMA_GRETYPE_END, 0}, + }, + {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}}, + {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}}, + { + {LLAMA_GRETYPE_RULE_REF, 8}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_RULE_REF, 9}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_CHAR, 40}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_RULE_REF, 2}, + {LLAMA_GRETYPE_CHAR, 41}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_END, 0}, + }, + {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}}, + { + {LLAMA_GRETYPE_CHAR, 45}, + {LLAMA_GRETYPE_CHAR_ALT, 43}, + {LLAMA_GRETYPE_CHAR_ALT, 42}, + {LLAMA_GRETYPE_CHAR_ALT, 47}, + {LLAMA_GRETYPE_RULE_REF, 4}, + {LLAMA_GRETYPE_END, 0}, + }, + {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}}, + { + {LLAMA_GRETYPE_CHAR, 97}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, + {LLAMA_GRETYPE_RULE_REF, 10}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_END, 0}, + }, + {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}}, + { + {LLAMA_GRETYPE_CHAR, 97}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, + {LLAMA_GRETYPE_CHAR_ALT, 48}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + {LLAMA_GRETYPE_CHAR_ALT, 95}, + {LLAMA_GRETYPE_RULE_REF, 10}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }, + { + {LLAMA_GRETYPE_CHAR, 48}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + {LLAMA_GRETYPE_RULE_REF, 11}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_CHAR, 48}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + {LLAMA_GRETYPE_END, 0}, + }, + { + {LLAMA_GRETYPE_CHAR, 32}, + {LLAMA_GRETYPE_CHAR_ALT, 9}, + {LLAMA_GRETYPE_CHAR_ALT, 10}, + {LLAMA_GRETYPE_RULE_REF, 12}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }, + }; + + for (auto pair : expected) + { + parsed_grammar.symbol_ids[pair.first] = pair.second; + } + + for (auto rule : expected_rules) + { + parsed_grammar.rules.push_back({}); + for (auto element : rule) + { + parsed_grammar.rules.back().push_back(element); + } + } + + llama_grammar *grammar = NULL; + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + + std::vector> expected_stacks = { + { + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_CHAR, 97}, + }, + { + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_CHAR, 40}, + }, + { + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_CHAR, 97}, + }, + { + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_CHAR, 40}, + }}; + + auto index = 0; + for (auto stack : grammar->stacks) + { + // compare stack to expected_stack + for (uint32_t i = 0; i < stack.size(); i++) + { + auto element = stack[i]; + auto expected_element = expected_stacks[index][i]; + + // pretty print error message before asserting + if (expected_element.type != element->type || expected_element.value != element->value) + { + fprintf(stderr, "index: %d\n", index); + fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value); + fprintf(stderr, "actual_element: %d, %d\n", element->type, element->value); + fprintf(stderr, "expected_element != actual_element\n"); + } + + assert(expected_element.type == element->type && expected_element.value == element->value); + } + index++; + } + + std::vector> next_stacks; + std::vector next_candidates; + next_candidates.resize(24); + + for (size_t i = 0; i < 24; ++i) + { + uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point + cp[0] = 37 + i; + cp[1] = 0; + next_candidates[i] = {i, cp}; + } + + std::vector>> expected_reject = { + { + {0, 37}, + {1, 38}, + {2, 39}, + {3, 40}, + {4, 41}, + {5, 42}, + {6, 43}, + {7, 44}, + {8, 45}, + {9, 46}, + {10, 47}, + {11, 48}, + {12, 49}, + {13, 50}, + {14, 51}, + {15, 52}, + {16, 53}, + {17, 54}, + {18, 55}, + {19, 56}, + {20, 57}, + {21, 58}, + {22, 59}, + {23, 60}, + }, + { + {0, 37}, + {1, 38}, + {2, 39}, + {3, 40}, + {4, 41}, + {5, 42}, + {6, 43}, + {7, 44}, + {8, 45}, + {9, 46}, + {10, 47}, + {21, 58}, + {22, 59}, + {23, 60}, + }, + { + {0, 37}, + {1, 38}, + {2, 39}, + {3, 40}, + {4, 41}, + {5, 42}, + {6, 43}, + {7, 44}, + {8, 45}, + {9, 46}, + {10, 47}, + {21, 58}, + {22, 59}, + {23, 60}, + }, + { + {0, 37}, + {1, 38}, + {2, 39}, + {4, 41}, + {5, 42}, + {6, 43}, + {7, 44}, + {8, 45}, + {9, 46}, + {10, 47}, + {11, 48}, + {12, 49}, + {13, 50}, + {14, 51}, + {15, 52}, + {16, 53}, + {17, 54}, + {18, 55}, + {19, 56}, + {20, 57}, + {21, 58}, + {22, 59}, + {23, 60}, + }, + { + {0, 37}, + {1, 38}, + {2, 39}, + {3, 40}, + {4, 41}, + {5, 42}, + {6, 43}, + {7, 44}, + {8, 45}, + {9, 46}, + {10, 47}, + {11, 48}, + {12, 49}, + {13, 50}, + {14, 51}, + {15, 52}, + {16, 53}, + {17, 54}, + {18, 55}, + {19, 56}, + {20, 57}, + {21, 58}, + {22, 59}, + {23, 60}, + }, + { + {0, 37}, + {1, 38}, + {2, 39}, + {3, 40}, + {4, 41}, + {5, 42}, + {6, 43}, + {7, 44}, + {8, 45}, + {9, 46}, + {10, 47}, + {21, 58}, + {22, 59}, + {23, 60}, + }, + { + {0, 37}, + {1, 38}, + {2, 39}, + {3, 40}, + {4, 41}, + {5, 42}, + {6, 43}, + {7, 44}, + {8, 45}, + {9, 46}, + {10, 47}, + {21, 58}, + {22, 59}, + {23, 60}, + }, + { + {0, 37}, + {1, 38}, + {2, 39}, + {4, 41}, + {5, 42}, + {6, 43}, + {7, 44}, + {8, 45}, + {9, 46}, + {10, 47}, + {11, 48}, + {12, 49}, + {13, 50}, + {14, 51}, + {15, 52}, + {16, 53}, + {17, 54}, + {18, 55}, + {19, 56}, + {20, 57}, + {21, 58}, + {22, 59}, + {23, 60}, + }, + }; + + std::vector rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates); + + std::vector> all_rejects; + + for (std::size_t count = 0; count < grammar->stacks.size(); ++count) + { + rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates); + all_rejects.push_back(rejects); + } + + index = 0; + for (auto rej : all_rejects) + { + for (uint32_t i = 0; i < rej.size(); i++) + { + auto element = rej[i]; + auto expected_element = expected_reject[index][i]; + assert(element.index == expected_element.first && *element.code_points == expected_element.second); + } + index++; + } + + for (auto &candidate : next_candidates) + { + delete[] candidate.code_points; + candidate.code_points = nullptr; + } + delete grammar; + return 0; +} From a73ccf1aa34de49f61bfeb7f8a679c3bfdb3abe3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 17 Aug 2023 10:47:09 +0300 Subject: [PATCH 09/10] llama : replace (permute + reshape + view_1d) with (view_3d) (#2538) ggml-ci --- llama.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/llama.cpp b/llama.cpp index 345243990..b8cc22942 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1609,11 +1609,11 @@ static struct ggml_cgraph * llama_build_graph( ggml_set_name(Q, "Q"); struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa), - n_embd_head, n_head_kv, n_past + N), - 0, 2, 1, 3); + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_past + N, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); offload_func_kq(K); ggml_set_name(K, "K"); @@ -1642,9 +1642,9 @@ static struct ggml_cgraph * llama_build_graph( struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, n_past + N, n_embd_head, n_head_kv, - n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_embd_head, - n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il); + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); offload_func_v(V); ggml_set_name(V, "V"); From 8dae7ce68437faf1fa96ec0e7687b8700956ef20 Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Thu, 17 Aug 2023 07:29:44 -0600 Subject: [PATCH 10/10] Add --cfg-negative-prompt-file option for examples (#2591) Add --cfg-negative-prompt-file option for examples --- examples/common.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index 9f8aab9a2..bd39d9220 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -274,6 +274,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.cfg_negative_prompt = argv[i]; + } else if (arg == "--cfg-negative-prompt-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.cfg_negative_prompt)); + if (params.cfg_negative_prompt.back() == '\n') { + params.cfg_negative_prompt.pop_back(); + } } else if (arg == "--cfg-scale") { if (++i >= argc) { invalid_param = true; @@ -567,8 +582,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); fprintf(stdout, " --grammar-file FNAME file to read grammar from\n"); - fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); + fprintf(stdout, " --cfg-negative-prompt PROMPT\n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); + fprintf(stdout, " --cfg-negative-prompt-file FNAME\n"); + fprintf(stdout, " negative prompt file to use for guidance. (default: empty)\n"); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale); fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);