#import "ggml-metal.h"
#import "ggml-impl.h"
#import "ggml-backend-impl.h"

#import <Foundation/Foundation.h>

#import <Metal/Metal.h>

#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64

// max number of MTLCommandBuffer used to submit a graph for processing
#define GGML_METAL_MAX_COMMAND_BUFFERS 8

#define UNUSED(x) (void)(x)

// globals

// overload of MTLGPUFamilyMetal3 (not available in some environments)
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;

// initialized in ggml_backend_metal_reg
static struct ggml_backend_reg    g_ggml_backend_metal_reg;
static struct ggml_backend_device g_ggml_backend_metal_device;

// information about a Metal device
// note: assumes single GPU device - the default one
// TODO: support multiple GPU devices
static struct ggml_backend_metal_device_context {
    id<MTLDevice> mtl_device;
    int           mtl_device_ref_count;

    bool has_simdgroup_reduction;
    bool has_simdgroup_mm;
    bool has_bfloat;
    bool use_bfloat;

    char name[128];
} g_ggml_ctx_dev_main = {
    /*.mtl_device              =*/ nil,
    /*.mtl_device_ref_count    =*/ 0,
    /*.has_simdgroup_reduction =*/ false,
    /*.has_simdgroup_mm        =*/ false,
    /*.has_bfloat              =*/ false,
    /*.use_bfloat              =*/ false,
    /*.name                    =*/ "",
};

// acquire
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
    assert(ctx != NULL);

    if (ctx->mtl_device == nil) {
        ctx->mtl_device = MTLCreateSystemDefaultDevice();

        ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
        ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];

        ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];

        ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
        ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];

#if defined(GGML_METAL_USE_BF16)
        ctx->use_bfloat = ctx->has_bfloat;
#else
        ctx->use_bfloat = false;
#endif

        strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
    }

    ctx->mtl_device_ref_count++;

    return ctx->mtl_device;
}

// release
static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
    assert(ctx != NULL);
    assert(ctx->mtl_device_ref_count > 0);

    ctx->mtl_device_ref_count--;

    if (ctx->mtl_device_ref_count == 0) {
        [ctx->mtl_device release];
        ctx->mtl_device = nil;
    }
}
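
//
// Illustrative usage sketch (comment only): the pair above implements simple reference
// counting around the shared default MTLDevice. A caller is expected to balance every
// _acq with a matching _rel, e.g.:
//
//     id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
//     // ... create command queues, pipelines, buffers on `device` ...
//     ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
//
// ggml_metal_init() below acquires the device this way; the matching release happens in
// the backend teardown paths (not shown in this part of the file).
//
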
// kernels

struct ggml_metal_kernel {
    id<MTLComputePipelineState> pipeline;
};

enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_ADD,
    GGML_METAL_KERNEL_TYPE_ADD_ROW,
    GGML_METAL_KERNEL_TYPE_SUB,
    GGML_METAL_KERNEL_TYPE_SUB_ROW,
    GGML_METAL_KERNEL_TYPE_MUL,
    GGML_METAL_KERNEL_TYPE_MUL_ROW,
    GGML_METAL_KERNEL_TYPE_DIV,
    GGML_METAL_KERNEL_TYPE_DIV_ROW,
    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
    GGML_METAL_KERNEL_TYPE_SCALE,
    GGML_METAL_KERNEL_TYPE_SCALE_4,
    GGML_METAL_KERNEL_TYPE_CLAMP,
    GGML_METAL_KERNEL_TYPE_TANH,
    GGML_METAL_KERNEL_TYPE_RELU,
    GGML_METAL_KERNEL_TYPE_SIGMOID,
    GGML_METAL_KERNEL_TYPE_GELU,
    GGML_METAL_KERNEL_TYPE_GELU_4,
    GGML_METAL_KERNEL_TYPE_GELU_QUICK,
    GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
    GGML_METAL_KERNEL_TYPE_SILU,
    GGML_METAL_KERNEL_TYPE_SILU_4,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
    GGML_METAL_KERNEL_TYPE_RMS_NORM,
    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
    GGML_METAL_KERNEL_TYPE_NORM,
    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
    //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
    //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
    //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
    GGML_METAL_KERNEL_TYPE_IM2COL_F16,
    GGML_METAL_KERNEL_TYPE_IM2COL_F32,
    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
    GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
    GGML_METAL_KERNEL_TYPE_PAD_F32,
    GGML_METAL_KERNEL_TYPE_ARANGE_F32,
    GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
    GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
    GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
    GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
    GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
    GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
    GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
    GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
    GGML_METAL_KERNEL_TYPE_CONCAT,
    GGML_METAL_KERNEL_TYPE_SQR,
    GGML_METAL_KERNEL_TYPE_SQRT,
    GGML_METAL_KERNEL_TYPE_SIN,
    GGML_METAL_KERNEL_TYPE_COS,
    GGML_METAL_KERNEL_TYPE_SUM_ROWS,
    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,

    GGML_METAL_KERNEL_TYPE_COUNT
};

struct ggml_backend_metal_context {
    id<MTLCommandQueue> queue;

    dispatch_queue_t d_queue;

    struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];

    // capture state
    bool capture_next_compute;
    bool capture_started;

    id<MTLCaptureScope> capture_scope;

    // command buffer state
    int n_cb;           // number of extra threads used to submit the command buffers
    int n_nodes_0;      // number of nodes submitted by the main thread
    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
    int n_nodes_per_cb;

    struct ggml_cgraph * gf;

    // the callback given to the thread pool
    void (^encode_async)(size_t ith);

    // n_cb command buffers + 1 used by the main thread
    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];

    // abort ggml_metal_graph_compute if callback returns true
    ggml_abort_callback abort_callback;
    void *              abort_callback_data;
};
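
//
// Illustrative sketch (comment only): at graph-encode time a kernel is selected by its
// enum value and its pre-built pipeline is bound to a compute encoder, roughly:
//
//     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
//     [encoder setComputePipelineState:pipeline];
//     [encoder dispatchThreadgroups:... threadsPerThreadgroup:...];
//
// where `encoder` is an id<MTLComputeCommandEncoder> created from one of the command
// buffers above; the actual per-op dispatch logic lives in the graph-encoding code of
// this backend and is not shown here.
//
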
// MSL code
// TODO: move the contents here when ready
//       for now it is easier to work in a separate file
// static NSString * const msl_library_source = @"see metal.metal";

// Here to assist with NSBundle Path Hack
@interface GGMLMetalClass : NSObject
@end
@implementation GGMLMetalClass
@end

static void * ggml_metal_host_malloc(size_t n) {
    void * data = NULL;

#if TARGET_OS_OSX
    kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
    if (err != KERN_SUCCESS) {
        GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
        return NULL;
    }
#else
    const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
    if (result != 0) {
        GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
        return NULL;
    }
#endif

    return data;
}
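
//
// Usage sketch (illustrative only): the helper above returns page-aligned host memory
// that the Metal backend can later wrap in a Metal buffer. Since the macOS branch uses
// vm_allocate() while the other branch uses posix_memalign(), the matching free is
// platform-dependent; a hypothetical caller (`size` is a placeholder name) would do
// roughly:
//
//     void * data = ggml_metal_host_malloc(size);
//     if (data == NULL) {
//         // handle allocation failure
//     }
//     ...
//     #if TARGET_OS_OSX
//     vm_deallocate((vm_map_t) mach_task_self(), (vm_address_t) data, size);
//     #else
//     free(data);
//     #endif
//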

static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
    GGML_LOG_INFO("%s: allocating\n", __func__);

#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
    // Show all the Metal device instances in the system
    NSArray * devices = MTLCopyAllDevices();
    for (id<MTLDevice> device in devices) {
        GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
    }
    [devices release]; // since it was created by a *Copy* C method
#endif

    // init context
    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
    struct ggml_backend_metal_device_context * ctx_dev = dev->context;

    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

    ctx->queue   = [device newCommandQueue];
    ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

    id<MTLLibrary> metal_library;

    // load library
    //
    // - first check if the library is embedded
    // - then check if the library is in the bundle
    // - if not found, load the source and compile it
    // - if that fails, return NULL
    {
        NSBundle * bundle = nil;
#ifdef SWIFT_PACKAGE
        bundle = SWIFTPM_MODULE_BUNDLE;
#else
        bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif

        NSError * error = nil;

#if GGML_METAL_EMBED_LIBRARY
        const bool try_metallib = false;
#else
        const bool try_metallib = true;
#endif

        NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
        if (try_metallib && path_lib != nil) {
            // pre-compiled library found
            NSURL * libURL = [NSURL fileURLWithPath:path_lib];
            GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);

            metal_library = [device newLibraryWithURL:libURL error:&error];
            if (error) {
                GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                return NULL;
            }
        } else {
#if GGML_METAL_EMBED_LIBRARY
            GGML_LOG_INFO("%s: using embedded metal library\n", __func__);

            extern const char ggml_metallib_start[];
            extern const char ggml_metallib_end[];

            NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
            GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);

            NSString * path_source;
            NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];

            GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");

            if (path_resource) {
                path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
            } else {
                path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
            }

            if (path_source == nil) {
                GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
                path_source = @"ggml-metal.metal";
            }

            GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);

            NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
            if (error) {
                GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                return NULL;
            }
#endif // GGML_METAL_EMBED_LIBRARY

            @autoreleasepool {
                // dictionary of preprocessor macros
                NSMutableDictionary * prep = [NSMutableDictionary dictionary];

                if (ctx_dev->use_bfloat) {
                    [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
                }

                MTLCompileOptions * options = [MTLCompileOptions new];
                options.preprocessorMacros = prep;

                //[options setFastMathEnabled:false];

                metal_library = [device newLibraryWithSource:src options:options error:&error];
                if (error) {
                    GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                    return NULL;
                }

#if !__has_feature(objc_arc)
                [options release];
#endif
            }
#if GGML_METAL_EMBED_LIBRARY
            [src release];
#endif // GGML_METAL_EMBED_LIBRARY
        }
    }

    // print MTL GPU family:
    GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]);

    // determine max supported GPU family
    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
    {
        for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
            if ([device supportsFamily:i]) {
                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
                break;
            }
        }

        for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
            if ([device supportsFamily:i]) {
                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
                break;
            }
        }

        for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
            if ([device supportsFamily:i]) {
                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
                break;
            }
        }
    }

    GGML_LOG_INFO("%s: simdgroup reduction   = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
    GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
    GGML_LOG_INFO("%s: has bfloat            = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
    GGML_LOG_INFO("%s: use bfloat            = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
    GGML_LOG_INFO("%s: hasUnifiedMemory      = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");

    ctx->capture_next_compute = false;
    ctx->capture_started      = false;
    ctx->capture_scope        = nil;

    ctx->gf           = nil;
    ctx->encode_async = nil;
    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
        ctx->command_buffers[i] = nil;
    }

#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
    if (@available(macOS 10.12, iOS 16.0, *)) {
        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
    }
#endif

    // load kernels
    {
        NSError * error = nil;

        for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
            ctx->kernels[i].pipeline = nil;
        }

#define GGML_METAL_ADD_KERNEL(e, name, supported) \
        if (supported) { \
            struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
            id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
            kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
            GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
                    (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
                    (int) kernel->pipeline.threadExecutionWidth); \
            [metal_function release]; \
            if (error) { \
                GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
                [metal_library release]; \
                return NULL; \
            } \
        } else { \
            GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
        }
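
        // Illustrative note (comment only): a single registration such as
        //
        //     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
        //
        // expands to code that looks up the MSL function "kernel_add" in metal_library,
        // builds an MTLComputePipelineState for it and stores it in
        // ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; when `supported` is false
        // the kernel is skipped with a warning instead. Enum values and Metal function
        // names are therefore expected to stay in sync ("..._TYPE_X" <-> "kernel_x").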

        const bool has_simdgroup_mm        = ctx_dev->has_simdgroup_mm;
        const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
        const bool use_bfloat              = ctx_dev->use_bfloat;

        // simd_sum and simd_max require MTLGPUFamilyApple7
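        //
        // Note (illustrative): these flags are meant to be passed as the `supported`
        // argument of GGML_METAL_ADD_KERNEL, so that kernels relying on simdgroup
        // reductions, simdgroup matrix multiplication or bfloat are registered only on
        // devices that report the corresponding capability. A hypothetical example of
        // such a gated registration (kernel/flag pairing shown for illustration only):
        //
        //     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
        //
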
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,     add,     true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,     sub,     true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
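        // repeat (broadcast) kernels for F32/F16/I32/I16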
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
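        // scale, clamp and unary activation kernels (tanh, relu, sigmoid, gelu, silu)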
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
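        // softmax kernels; gated on simdgroup reduction support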
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
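        // causal masking (diag_mask_inf) kernels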
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
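        // get_rows kernels, one per supported tensor type (float, bfloat and quantized formats)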
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
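        // normalization kernels (rms_norm and group_norm require simdgroup reduction)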
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
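        // state-space model (SSM) conv/scan kernels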
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
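        // matrix-vector multiplication kernels, one per source type; bfloat variants additionally require use_bfloat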
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
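        // mul_mv_id kernels: matrix-vector multiplication with indirection (expert selection)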
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
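        // matrix-matrix multiplication kernels; gated on simdgroup matrix multiplication support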
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
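        // mul_mm_id kernels: matrix-matrix multiplication with indirection (expert selection)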
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
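        // rotary position embedding (RoPE) kernels: normal and NeoX variants, F32/F16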
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
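        // im2col kernels (standard and extended variants)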
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
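        // misc kernels: upscale, pad, timestep embedding, arange, argsort, leaky_relu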
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
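        // flash attention kernels, specialized per KV type and head size; require simdgroup matrix multiplication, bfloat variants additionally require use_bfloat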
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
|
|
|
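// the single-batch "vec" flash attention kernels are gated on has_simdgroup_reduction
// rather than has_simdgroup_mm (the BF16 vec kernels also require use_bfloat)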
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
2024-11-08 20:59:46 +01:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
|
2024-11-06 18:53:51 +01:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
2024-11-08 20:59:46 +01:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
|
2024-11-06 18:53:51 +01:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
2024-04-30 11:16:08 +02:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
2024-11-06 18:53:51 +01:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
2024-11-08 20:59:46 +01:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
|
2024-07-13 17:32:33 +02:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
2024-11-06 18:53:51 +01:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
2024-11-08 20:59:46 +01:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
|
2024-04-30 11:16:08 +02:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
2024-08-27 21:01:45 +02:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
2024-04-30 11:16:08 +02:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
2024-10-23 12:33:45 +02:00
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
2023-06-04 22:34:30 +02:00
|
|
|
}
|
|
|
|
|
2024-01-28 20:50:16 +01:00
|
|
|
[metal_library release];
|
2024-10-01 15:00:25 +02:00
|
|
|
|
2023-06-04 22:34:30 +02:00
|
|
|
return ctx;
|
|
|
|
}
|
|
|
|
|
2024-08-07 08:57:00 +02:00
|
|
|
static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_INFO("%s: deallocating\n", __func__);
|
2023-08-28 09:59:08 +02:00
|
|
|
|
2024-01-28 20:50:16 +01:00
|
|
|
for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
|
|
|
|
[ctx->kernels[i].pipeline release];
|
2024-01-13 17:03:45 +01:00
|
|
|
}
|
|
|
|
|
2024-10-07 14:26:31 +02:00
|
|
|
Block_release(ctx->encode_async);
|
|
|
|
|
2023-08-28 09:59:08 +02:00
|
|
|
[ctx->queue release];
|
|
|
|
|
|
|
|
dispatch_release(ctx->d_queue);
|
|
|
|
|
2023-06-04 22:34:30 +02:00
|
|
|
free(ctx);
|
|
|
|
}
|
|
|
|
|
2023-12-07 21:26:54 +01:00
|
|
|
// temporarily defined here for compatibility between ggml-backend and the old API
|
2023-12-21 21:07:46 +01:00
|
|
|
|
|
|
|
struct ggml_backend_metal_buffer {
|
|
|
|
void * data;
|
|
|
|
size_t size;
|
2023-12-07 21:26:54 +01:00
|
|
|
|
|
|
|
id<MTLBuffer> metal;
|
|
|
|
};
|
|
|
|
|
2023-12-21 21:07:46 +01:00
|
|
|
struct ggml_backend_metal_buffer_context {
|
|
|
|
void * all_data;
|
|
|
|
size_t all_size;
|
|
|
|
bool owned;
|
|
|
|
|
|
|
|
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
|
|
|
|
int n_buffers;
|
|
|
|
struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
|
|
|
};
|
|
|
|
|
2023-06-04 22:34:30 +02:00
|
|
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
|
|
|
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
|
|
|
// Metal buffer based on the host memory pointer
|
|
|
|
//
|
2024-01-26 13:16:07 +01:00
|
|
|
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
|
2024-10-03 17:39:03 +02:00
|
|
|
//GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
2023-06-04 22:34:30 +02:00
|
|
|
|
2023-06-18 08:09:47 +02:00
|
|
|
const int64_t tsize = ggml_nbytes(t);
|
|
|
|
|
2023-12-21 21:07:46 +01:00
|
|
|
ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
|
|
|
|
|
2024-01-26 13:16:07 +01:00
|
|
|
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
|
2023-11-13 15:55:52 +01:00
|
|
|
|
2023-06-18 08:09:47 +02:00
|
|
|
// find the view that contains the tensor fully
|
2024-01-26 13:16:07 +01:00
|
|
|
for (int i = 0; i < buf_ctx->n_buffers; ++i) {
|
|
|
|
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
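// ioffs is the byte offset of the tensor data inside host buffer i; the tensor belongs to
// this buffer only if the offset is non-negative and offset + tsize fits within the buffer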
|
2023-06-04 22:34:30 +02:00
|
|
|
|
2024-10-03 17:39:03 +02:00
|
|
|
//GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
|
2024-01-26 13:16:07 +01:00
|
|
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
|
2023-06-04 22:34:30 +02:00
|
|
|
*offs = (size_t) ioffs;
|
|
|
|
|
2024-10-03 17:39:03 +02:00
|
|
|
//GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
|
2023-06-04 22:34:30 +02:00
|
|
|
|
2024-01-26 13:16:07 +01:00
|
|
|
return buf_ctx->buffers[i].metal;
|
2023-06-04 22:34:30 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
|
2023-06-04 22:34:30 +02:00
|
|
|
|
|
|
|
return nil;
|
|
|
|
}
|
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
|
2024-11-06 18:53:51 +01:00
|
|
|
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
|
|
|
|
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
|
2024-11-08 20:59:46 +01:00
|
|
|
const bool use_bfloat = ctx_dev->use_bfloat;
|
2024-11-06 18:53:51 +01:00
|
|
|
|
2024-11-08 20:59:46 +01:00
|
|
|
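// if bfloat is disabled (at build time or unsupported by the device), reject any op with a BF16 source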
if (!use_bfloat) {
|
2024-11-06 18:53:51 +01:00
|
|
|
for (size_t i = 0, n = 3; i < n; ++i) {
|
|
|
|
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
|
|
|
return false;
|
|
|
|
}
|
2024-06-20 07:32:01 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-07 21:26:54 +01:00
|
|
|
switch (op->op) {
|
|
|
|
case GGML_OP_UNARY:
|
|
|
|
switch (ggml_get_unary_op(op)) {
|
2023-12-13 20:54:54 +01:00
|
|
|
case GGML_UNARY_OP_TANH:
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_UNARY_OP_RELU:
|
2024-05-01 23:44:26 +02:00
|
|
|
case GGML_UNARY_OP_SIGMOID:
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_UNARY_OP_GELU:
|
2023-12-13 20:54:54 +01:00
|
|
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
|
|
case GGML_UNARY_OP_SILU:
|
2024-06-12 15:00:22 +02:00
|
|
|
return ggml_is_contiguous(op->src[0]);
|
2023-12-07 21:26:54 +01:00
|
|
|
default:
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
case GGML_OP_NONE:
|
|
|
|
case GGML_OP_RESHAPE:
|
|
|
|
case GGML_OP_VIEW:
|
2023-12-13 13:04:25 +01:00
|
|
|
case GGML_OP_TRANSPOSE:
|
2023-12-13 20:54:54 +01:00
|
|
|
case GGML_OP_PERMUTE:
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_OP_CONCAT:
|
|
|
|
case GGML_OP_ADD:
|
2024-08-27 21:01:45 +02:00
|
|
|
case GGML_OP_SUB:
|
2023-12-13 20:54:54 +01:00
|
|
|
case GGML_OP_ACC:
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_OP_MUL:
|
|
|
|
case GGML_OP_DIV:
|
2024-05-27 11:10:19 +02:00
|
|
|
case GGML_OP_REPEAT:
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_OP_SCALE:
|
2024-04-14 13:14:19 +02:00
|
|
|
case GGML_OP_CLAMP:
|
2024-08-27 21:01:45 +02:00
|
|
|
return true;
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_OP_SQR:
|
2024-08-27 21:01:45 +02:00
|
|
|
case GGML_OP_SQRT:
|
|
|
|
case GGML_OP_SIN:
|
|
|
|
case GGML_OP_COS:
|
|
|
|
return ggml_is_contiguous(op->src[0]);
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_OP_SUM_ROWS:
|
|
|
|
case GGML_OP_SOFT_MAX:
|
|
|
|
case GGML_OP_RMS_NORM:
|
2023-12-13 20:54:54 +01:00
|
|
|
case GGML_OP_GROUP_NORM:
|
2024-11-06 18:53:51 +01:00
|
|
|
return has_simdgroup_reduction;
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_OP_NORM:
|
|
|
|
case GGML_OP_ROPE:
|
2024-01-31 14:35:41 +01:00
|
|
|
return true;
|
2024-09-08 08:57:57 +02:00
|
|
|
case GGML_OP_IM2COL:
|
|
|
|
return op->src[0]->type == GGML_TYPE_F16;
|
2024-01-31 14:35:41 +01:00
|
|
|
case GGML_OP_POOL_1D:
|
|
|
|
return false;
|
2024-10-23 12:33:45 +02:00
|
|
|
case GGML_OP_POOL_2D:
|
2023-12-13 20:54:54 +01:00
|
|
|
case GGML_OP_UPSCALE:
|
|
|
|
case GGML_OP_PAD:
|
2024-03-03 13:23:52 +01:00
|
|
|
case GGML_OP_ARANGE:
|
|
|
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_OP_ARGSORT:
|
2023-12-13 20:54:54 +01:00
|
|
|
case GGML_OP_LEAKY_RELU:
|
2024-01-13 17:03:45 +01:00
|
|
|
return true;
|
2024-05-10 17:20:10 +02:00
|
|
|
case GGML_OP_FLASH_ATTN_EXT:
|
2024-11-06 09:24:23 +01:00
|
|
|
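// flash attention requires K (src1) and V (src2) to share the same type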
if (op->src[1]->type != op->src[2]->type) {
|
2024-05-27 09:38:39 +02:00
|
|
|
return false;
|
|
|
|
}
|
2024-11-06 18:53:51 +01:00
|
|
|
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
2024-08-26 16:55:36 +02:00
|
|
|
case GGML_OP_SSM_CONV:
|
|
|
|
case GGML_OP_SSM_SCAN:
|
|
|
|
return true;
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_OP_MUL_MAT:
|
|
|
|
case GGML_OP_MUL_MAT_ID:
|
2024-11-06 18:53:51 +01:00
|
|
|
return has_simdgroup_reduction &&
|
2024-01-23 14:50:56 +01:00
|
|
|
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
|
2023-12-13 13:04:25 +01:00
|
|
|
case GGML_OP_CPY:
|
|
|
|
case GGML_OP_DUP:
|
|
|
|
case GGML_OP_CONT:
|
|
|
|
{
|
|
|
|
switch (op->src[0]->type) {
|
|
|
|
case GGML_TYPE_F32:
|
|
|
|
switch (op->type) {
|
|
|
|
case GGML_TYPE_F32:
|
2024-07-13 17:32:33 +02:00
|
|
|
case GGML_TYPE_F16:
|
2024-11-06 18:53:51 +01:00
|
|
|
case GGML_TYPE_BF16:
|
2023-12-13 13:04:25 +01:00
|
|
|
case GGML_TYPE_Q8_0:
|
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
|
case GGML_TYPE_Q4_1:
|
2024-03-21 08:27:57 +01:00
|
|
|
case GGML_TYPE_Q5_0:
|
|
|
|
case GGML_TYPE_Q5_1:
|
|
|
|
case GGML_TYPE_IQ4_NL:
|
2023-12-13 13:04:25 +01:00
|
|
|
return true;
|
|
|
|
default:
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
case GGML_TYPE_F16:
|
|
|
|
switch (op->type) {
|
2024-11-06 18:53:51 +01:00
|
|
|
case GGML_TYPE_F32:
|
|
|
|
case GGML_TYPE_F16:
|
2023-12-13 13:04:25 +01:00
|
|
|
return true;
|
2024-11-06 18:53:51 +01:00
|
|
|
default:
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
case GGML_TYPE_BF16:
|
|
|
|
switch (op->type) {
|
|
|
|
case GGML_TYPE_F32:
|
|
|
|
case GGML_TYPE_BF16:
|
|
|
|
return true;
|
|
|
|
default:
|
2023-12-13 13:04:25 +01:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
default:
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
2023-12-07 21:26:54 +01:00
|
|
|
case GGML_OP_DIAG_MASK_INF:
|
2023-12-13 20:54:54 +01:00
|
|
|
case GGML_OP_GET_ROWS:
|
2023-12-07 21:26:54 +01:00
|
|
|
{
|
2024-07-13 17:32:33 +02:00
|
|
|
return op->ne[3] == 1;
|
2023-12-07 21:26:54 +01:00
|
|
|
}
|
|
|
|
default:
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
2024-01-13 17:03:45 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
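// encode a single graph node (ctx->gf node at index idx) into the given compute command encoder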
static void ggml_metal_encode_node(
|
2024-10-07 17:27:51 +02:00
|
|
|
ggml_backend_t backend,
|
2024-10-01 15:00:25 +02:00
|
|
|
int idx,
|
|
|
|
id<MTLComputeCommandEncoder> encoder) {
|
2024-10-07 17:27:51 +02:00
|
|
|
struct ggml_backend_metal_context * ctx = backend->context;
|
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
struct ggml_cgraph * gf = ctx->gf;
|
2023-08-28 09:59:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
struct ggml_tensor * node = ggml_graph_node(gf, idx);
|
2023-07-25 14:00:19 +02:00
|
|
|
|
2024-10-03 17:39:03 +02:00
|
|
|
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
|
2023-06-15 19:29:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
struct ggml_tensor * src0 = node->src[0];
|
|
|
|
struct ggml_tensor * src1 = node->src[1];
|
|
|
|
struct ggml_tensor * src2 = node->src[2];
|
|
|
|
struct ggml_tensor * dst = node;
|
2023-06-15 19:29:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
if (ggml_is_empty(dst)) {
|
|
|
|
return;
|
|
|
|
}
|
2024-01-29 10:22:23 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
switch (dst->op) {
|
|
|
|
case GGML_OP_NONE:
|
|
|
|
case GGML_OP_RESHAPE:
|
|
|
|
case GGML_OP_VIEW:
|
|
|
|
case GGML_OP_TRANSPOSE:
|
|
|
|
case GGML_OP_PERMUTE:
|
|
|
|
{
|
|
|
|
// noop -> next node
|
|
|
|
} return;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
} break;
|
|
|
|
}
|
2024-01-29 10:22:23 +01:00
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
if (!ggml_metal_supports_op(ctx_dev, dst)) {
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
2024-10-01 15:00:25 +02:00
|
|
|
GGML_ABORT("unsupported op");
|
2024-01-29 10:22:23 +01:00
|
|
|
}
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
|
|
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
|
|
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
|
|
|
const int64_t ne03 = src0 ? src0->ne[3] : 0;
|
|
|
|
|
|
|
|
const uint64_t nb00 = src0 ? src0->nb[0] : 0;
|
|
|
|
const uint64_t nb01 = src0 ? src0->nb[1] : 0;
|
|
|
|
const uint64_t nb02 = src0 ? src0->nb[2] : 0;
|
|
|
|
const uint64_t nb03 = src0 ? src0->nb[3] : 0;
|
|
|
|
|
|
|
|
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
|
|
|
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
|
|
|
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
|
|
|
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
|
|
|
|
|
|
|
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
|
|
|
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
|
|
|
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
|
|
|
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
|
|
|
|
|
|
|
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
|
|
|
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
|
|
|
const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
|
|
|
|
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
|
|
|
|
|
|
|
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
|
|
|
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
|
|
|
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
2024-11-08 12:47:22 +01:00
|
|
|
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
2024-10-01 15:00:25 +02:00
|
|
|
|
|
|
|
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
|
|
|
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
|
|
|
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
|
|
|
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
|
|
|
|
|
|
|
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
|
|
|
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
|
|
|
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
|
|
|
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
|
|
|
|
|
|
|
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
|
|
|
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
|
|
|
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
|
|
|
|
|
|
|
|
size_t offs_src0 = 0;
|
|
|
|
size_t offs_src1 = 0;
|
|
|
|
size_t offs_src2 = 0;
|
|
|
|
size_t offs_dst = 0;
|
|
|
|
|
|
|
|
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
|
|
|
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
|
|
|
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
|
|
|
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
|
|
|
|
2024-10-25 21:26:15 +02:00
|
|
|
#if 0
|
|
|
|
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
|
|
|
if (src0) {
|
|
|
|
GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
|
|
|
|
ggml_is_contiguous(src0), src0->name);
|
|
|
|
}
|
|
|
|
if (src1) {
|
|
|
|
GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
|
|
|
ggml_is_contiguous(src1), src1->name);
|
|
|
|
}
|
|
|
|
if (dst) {
|
|
|
|
GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
|
|
|
|
dst->name);
|
|
|
|
}
|
|
|
|
#endif
|
2024-10-01 15:00:25 +02:00
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
switch (dst->op) {
|
|
|
|
case GGML_OP_CONCAT:
|
|
|
|
{
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
|
|
|
|
|
|
|
const int32_t dim = ((const int32_t *) dst->op_params)[0];
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
|
|
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
|
|
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
|
|
|
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
|
|
|
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
|
|
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
|
|
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
|
|
|
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
|
|
|
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
|
|
|
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
|
|
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
|
|
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
|
|
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
|
|
|
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
|
|
|
|
|
|
|
|
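// launch a ne1 x ne2 x ne3 grid of threadgroups, each covering dim 0 with up to 1024 threads
// (e.g. ne0 = 4096, ne1 = 32, ne2 = ne3 = 1 gives a 32x1x1 grid of 1024-thread threadgroups)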
const int nth = MIN(1024, ne0);
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_ADD:
|
|
|
|
case GGML_OP_SUB:
|
|
|
|
case GGML_OP_MUL:
|
|
|
|
case GGML_OP_DIV:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
|
|
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
2023-06-04 22:34:30 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const size_t offs = 0;
|
2024-01-29 10:22:23 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
bool bcast_row = false;
|
2023-06-15 19:29:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
int64_t nb = ne00; // used by the "row" kernels
|
2023-09-28 18:04:36 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
2023-12-13 13:04:25 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
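// fast path: src1 is a single contiguous row and both row sizes are multiples of 4, so the
// *_ROW kernels can be used (they appear to process the row in packs of 4 elements, hence nb = ne00/4 below)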
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
2023-12-13 13:04:25 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
// src1 is a row
|
|
|
|
GGML_ASSERT(ne11 == 1);
|
2023-09-28 18:04:36 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
nb = ne00 / 4;
|
|
|
|
switch (dst->op) {
|
|
|
|
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
|
|
|
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
|
|
|
|
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
|
|
|
|
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
|
|
|
|
default: GGML_ABORT("fatal error");
|
|
|
|
}
|
2023-12-13 13:04:25 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
bcast_row = true;
|
|
|
|
} else {
|
|
|
|
switch (dst->op) {
|
|
|
|
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
|
|
|
|
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
|
|
|
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
|
|
|
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
|
|
|
default: GGML_ABORT("fatal error");
|
|
|
|
}
|
|
|
|
}
|
2024-01-16 14:41:27 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
|
|
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
|
|
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
|
|
|
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
|
|
|
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
|
|
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
|
|
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
|
|
|
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
|
|
|
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
|
|
|
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
|
|
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
|
|
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
|
|
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
|
|
|
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
|
|
|
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
|
|
|
|
|
|
|
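// row-broadcast kernels: one single-thread threadgroup per 4-element pack of dst;
// general kernels: one threadgroup per row of src0, up to maxTotalThreadsPerThreadgroup threads wide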
if (bcast_row) {
|
|
|
|
const int64_t n = ggml_nelements(dst)/4;
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} else {
|
|
|
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
case GGML_OP_REPEAT:
|
|
|
|
{
|
|
|
|
id<MTLComputePipelineState> pipeline;
|
|
|
|
|
|
|
|
switch (src0t) {
|
|
|
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
|
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
|
|
|
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
|
|
|
|
case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
|
|
|
|
default: GGML_ABORT("fatal error");
|
|
|
|
}
|
2024-01-16 14:41:27 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
|
|
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
|
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
|
|
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
|
|
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
|
|
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
|
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
|
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
|
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
|
|
|
|
|
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_ACC:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
|
|
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
|
|
GGML_ASSERT(dstt == GGML_TYPE_F32);
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
|
|
|
|
|
|
|
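// ACC op_params appear to encode the destination view: byte strides pnb1..pnb3, a byte offset,
// and an in-place flag (when not in-place, src0 is first copied into dst below)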
const size_t pnb1 = ((const int32_t *) dst->op_params)[0];
|
|
|
|
const size_t pnb2 = ((const int32_t *) dst->op_params)[1];
|
|
|
|
const size_t pnb3 = ((const int32_t *) dst->op_params)[2];
|
|
|
|
const size_t offs = ((const int32_t *) dst->op_params)[3];
|
|
|
|
|
|
|
|
const bool inplace = (bool) ((const int32_t *) dst->op_params)[4];
|
|
|
|
|
|
|
|
if (!inplace) {
|
|
|
|
// run a separate kernel to cpy src->dst
|
|
|
|
// not sure how to avoid this
|
|
|
|
// TODO: make a simpler cpy_bytes kernel
|
|
|
|
|
|
|
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
|
|
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
|
|
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
|
|
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
|
|
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
|
|
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
|
|
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
|
|
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
|
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
|
|
|
|
|
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
|
}
|
2024-03-26 15:46:41 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
|
|
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
|
|
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
|
|
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
|
|
|
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
|
|
|
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
|
|
|
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
|
|
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
|
|
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
|
|
|
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
|
|
|
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
|
|
|
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
|
|
|
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
|
|
|
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
|
|
|
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
|
|
|
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
|
|
|
|
|
|
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_SCALE:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
2023-09-28 18:04:36 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
float scale;
|
|
|
|
memcpy(&scale, dst->op_params, sizeof(scale));
|
2023-06-15 19:29:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
int64_t n = ggml_nelements(dst);
|
2023-12-13 20:54:54 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
2023-12-13 20:54:54 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
if (n % 4 == 0) {
|
|
|
|
n /= 4;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
|
|
|
|
}
|
2023-10-24 08:46:50 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
2023-12-13 20:54:54 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_CLAMP:
|
|
|
|
{
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
|
2024-05-27 11:10:19 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
float min;
|
|
|
|
float max;
|
|
|
|
memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float));
|
|
|
|
memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
|
2024-01-13 17:03:45 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
[encoder setBytes:&min length:sizeof(min) atIndex:2];
|
|
|
|
[encoder setBytes:&max length:sizeof(max) atIndex:3];
|
2023-06-15 19:29:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int64_t n = ggml_nelements(dst);
|
2024-01-13 17:03:45 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_UNARY:
|
|
|
|
switch (ggml_get_unary_op(node)) {
|
|
|
|
// we are not taking into account the strides, so for now require contiguous tensors
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
2023-06-15 19:29:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
case GGML_UNARY_OP_TANH:
|
|
|
|
{
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
|
2023-06-15 19:29:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2023-12-07 21:26:54 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int64_t n = ggml_nelements(dst);
|
2024-05-01 23:44:26 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_UNARY_OP_RELU:
|
|
|
|
{
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
|
2024-05-01 23:44:26 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2024-05-01 23:44:26 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int64_t n = ggml_nelements(dst);
|
2023-10-08 09:01:53 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_UNARY_OP_SIGMOID:
|
|
|
|
{
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
|
2023-10-08 09:01:53 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2023-10-08 09:01:53 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int64_t n = ggml_nelements(dst);
|
2024-04-16 17:40:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_UNARY_OP_GELU:
|
|
|
|
{
|
|
|
|
int64_t n = ggml_nelements(dst);
|
2024-04-16 17:40:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
2024-01-13 17:03:45 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
if (n % 4 == 0) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
|
|
|
|
n /= 4;
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
|
|
|
|
}
|
2023-06-15 19:29:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2024-04-16 17:40:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
|
|
{
|
|
|
|
int64_t n = ggml_nelements(dst);
|
2024-04-16 17:40:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
2024-01-13 17:03:45 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
if (n % 4 == 0) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
|
|
|
|
n /= 4;
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
|
|
|
|
}
|
2023-12-13 13:04:25 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2024-04-16 17:40:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_UNARY_OP_SILU:
|
|
|
|
{
|
|
|
|
int64_t n = ggml_nelements(dst);
|
2024-04-16 17:40:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
2024-01-02 20:07:47 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
if (n % 4 == 0) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
|
|
|
|
n /= 4;
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
|
|
|
|
}
|
2024-01-13 17:03:45 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2023-06-15 19:29:48 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
default:
|
|
|
|
{
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
2024-10-01 15:00:25 +02:00
|
|
|
GGML_ABORT("fatal error");
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
case GGML_OP_SQR:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
2024-03-03 13:23:52 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
|
2024-02-17 22:04:16 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2024-03-03 13:23:52 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int64_t n = ggml_nelements(dst);
|
2024-02-17 22:04:16 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_SQRT:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
2023-06-17 16:37:49 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
|
2023-06-17 16:37:49 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2023-06-17 19:24:11 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int64_t n = ggml_nelements(dst);
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_SIN:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
|
|
|
|
const int64_t n = ggml_nelements(dst);
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_COS:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
|
|
|
|
const int64_t n = ggml_nelements(dst);
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_SUM_ROWS:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
|
|
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
|
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
|
|
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
|
|
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
|
|
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
|
|
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
|
|
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
|
|
|
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
|
|
|
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
|
|
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
|
|
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
|
|
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_SOFT_MAX:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
|
|
|
|
|
|
|
int nth = 32; // SIMD width
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
|
|
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
|
|
|
|
|
|
|
if (ne00%4 == 0) {
|
|
|
|
while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
|
|
|
|
nth *= 2;
|
|
|
|
}
|
|
|
|
if (use_f16) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
|
|
|
nth *= 2;
|
|
|
|
}
|
|
|
|
if (use_f16) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
float scale;
|
|
|
|
float max_bias;
|
|
|
|
|
|
|
|
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
|
|
|
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
|
|
|
|
|
|
|
const int64_t nrows_x = ggml_nrows(src0);
|
|
|
|
const int64_t nrows_y = src0->ne[1];
|
|
|
|
|
|
|
|
const uint32_t n_head = nrows_x/nrows_y;
|
|
|
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
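// note: the two slope bases below follow the standard ALiBi scheme (an assumption, not stated here):
//       heads h < n_head_log2 would use slope m0^(h+1), the remaining heads interleave m1^(2*(h - n_head_log2) + 1);
//       with max_bias == 0.0f both bases evaluate to 1.0f and the bias is effectively disabled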
|
|
|
|
|
|
|
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
|
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
if (id_src1) {
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
} else {
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
|
|
}
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
|
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
|
|
|
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
|
|
|
|
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
|
|
|
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
|
|
|
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
|
|
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_DIAG_MASK_INF:
|
|
|
|
{
|
|
|
|
const int n_past = ((const int32_t *)(dst->op_params))[0];
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
|
|
|
if (ne00%8 == 0) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
|
|
|
}
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
|
|
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
|
|
|
|
|
|
|
if (ne00%8 == 0) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
}
|
|
|
|
else {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
case GGML_OP_SSM_CONV:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
|
|
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
|
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
|
|
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
|
|
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
|
|
|
|
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
|
|
|
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
|
|
|
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_SSM_SCAN:
|
|
|
|
{
|
|
|
|
struct ggml_tensor * src3 = node->src[3];
|
|
|
|
struct ggml_tensor * src4 = node->src[4];
|
|
|
|
struct ggml_tensor * src5 = node->src[5];
|
|
|
|
|
|
|
|
GGML_ASSERT(src3);
|
|
|
|
GGML_ASSERT(src4);
|
|
|
|
GGML_ASSERT(src5);
|
|
|
|
|
|
|
|
size_t offs_src3 = 0;
|
|
|
|
size_t offs_src4 = 0;
|
|
|
|
size_t offs_src5 = 0;
|
|
|
|
|
|
|
|
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
|
|
|
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
|
|
|
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
|
|
|
|
|
|
|
const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
|
|
|
|
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
|
|
|
|
|
|
|
const uint64_t nb30 = src3->nb[0];
|
|
|
|
const uint64_t nb31 = src3->nb[1];
|
|
|
|
|
|
|
|
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
|
|
|
const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
|
|
|
|
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
|
|
|
|
|
|
|
const uint64_t nb40 = src4->nb[0];
|
|
|
|
const uint64_t nb41 = src4->nb[1];
|
|
|
|
const uint64_t nb42 = src4->nb[2];
|
|
|
|
|
|
|
|
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
|
|
|
|
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
|
|
|
|
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
|
|
|
|
|
|
|
const uint64_t nb50 = src5->nb[0];
|
|
|
|
const uint64_t nb51 = src5->nb[1];
|
|
|
|
const uint64_t nb52 = src5->nb[2];
|
|
|
|
|
|
|
|
const int64_t d_state = ne00;
|
|
|
|
const int64_t d_inner = ne01;
|
|
|
|
const int64_t n_seq_tokens = ne11;
|
|
|
|
const int64_t n_seqs = ne02;
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
|
|
|
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
|
|
|
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
|
|
|
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
|
|
|
|
|
|
|
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
|
|
|
|
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
|
|
|
|
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
|
|
|
|
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
|
|
|
|
|
|
|
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
|
|
|
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
|
|
|
|
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
|
|
|
|
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
|
|
|
|
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
|
|
|
|
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
|
|
|
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
|
|
|
|
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
|
|
|
|
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
|
|
|
|
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
|
|
|
|
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
|
|
|
|
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_MUL_MAT:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(ne00 == ne10);
|
|
|
|
|
|
|
|
GGML_ASSERT(ne12 % ne02 == 0);
|
|
|
|
GGML_ASSERT(ne13 % ne03 == 0);
|
|
|
|
|
|
|
|
const uint r2 = ne12/ne02;
|
|
|
|
const uint r3 = ne13/ne03;
|
|
|
|
|
|
|
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
|
|
|
// to the matrix-vector kernel
|
|
|
|
int ne11_mm_min = 1;
|
2024-01-13 17:03:45 +01:00
|
|
|
|
2024-01-16 14:41:27 +01:00
|
|
|
#if 0
|
2024-10-01 15:00:25 +02:00
|
|
|
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
|
|
|
// these numbers do not translate to other devices or model sizes
|
|
|
|
// TODO: need to find a better approach
|
2024-10-07 17:27:51 +02:00
|
|
|
if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
|
2024-01-16 14:41:27 +01:00
|
|
|
switch (src0t) {
|
|
|
|
case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
|
|
|
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
|
|
|
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
|
|
|
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
|
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
|
|
|
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
|
|
|
case GGML_TYPE_Q5_0: // not tested yet
|
|
|
|
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
|
|
|
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
|
|
|
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
|
|
|
default: ne11_mm_min = 1; break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
2023-08-23 22:08:04 +02:00
|
|
|
|
2024-01-16 14:41:27 +01:00
|
|
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
|
|
// AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
|
2024-10-07 17:27:51 +02:00
|
|
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
2024-10-01 15:00:25 +02:00
|
|
|
!ggml_is_transposed(src0) &&
|
|
|
|
!ggml_is_transposed(src1) &&
|
|
|
|
src1t == GGML_TYPE_F32 &&
|
|
|
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
|
|
|
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
2024-01-16 14:41:27 +01:00
|
|
|
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
llama : add custom RoPE (#2054)
* Implement customizable RoPE
The original RoPE has pre-defined parameters
theta_i = 10000^(−2(i−1)/d), for i in [1, 2, ..., d/2]
Our customizable RoPE, ggml_rope_custom_inplace, uses
theta_i = scale * base^(−2(i−1)/d), for i in [1, 2, ..., d/2]
with the defaults matching the original:
scale = 1.0
base = 10000
The new command line arguments
--rope-freq-base
--rope-freq-scale
set the two new RoPE parameters.
Recent research shows that changing these two parameters extends the context limit with minimal loss.
1. Extending Context to 8K
kaiokendev
https://kaiokendev.github.io/til#extending-context-to-8k
2. Extending Context Window of Large Language Models via Positional Interpolation
Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian
https://arxiv.org/abs/2306.15595
3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.
https://www.reddit.com/user/bloc97
https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
For the bold, try adding the following command line parameters to your favorite model:
-c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5
* ggml-metal: fix custom rope
* common: fix argument names in help
* llama: increase MEM_REQ_EVAL for MODEL_3B
It avoids crashing for quantized weights on CPU.
A more accurate way to calculate the required buffer size would be preferable.
* llama: make MEM_REQ_EVAL depend on n_ctx
* server: use proper Content-Type in curl examples
Without the header Content-Type: application/json, curl will POST with
Content-Type: application/x-www-form-urlencoded
Though our simple server doesn't care, the httplib.h we use caps form-urlencoded payloads
via CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH (8192).
With Content-Type: application/json, we can send large json data.
* style : minor fixes, mostly indentations
* ggml : fix asserts
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-07-15 12:34:16 +02:00
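A minimal sketch of the frequency schedule described in the commit message above, using only the stated formula theta_i = scale * base^(-2(i-1)/d); the helper name rope_custom_thetas is illustrative and not part of the actual API:

#include <math.h>

// fills theta[0 .. d/2 - 1] with theta_i = scale * base^(-2*(i-1)/d), for i = 1..d/2
static void rope_custom_thetas(float * theta, int d, float freq_base, float freq_scale) {
    for (int i = 1; i <= d/2; ++i) {
        theta[i - 1] = freq_scale * powf(freq_base, -2.0f*(i - 1)/(float) d);
    }
}

// freq_base = 10000.0f, freq_scale = 1.0f reproduces the original RoPE schedule;
// the example flags above correspond to freq_base = 80000.0f, freq_scale = 0.5f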
|
|
|
|
2024-03-22 10:35:53 +01:00
|
|
|
// some Metal matrix data types require aligned pointers
|
|
|
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
|
|
|
switch (src0->type) {
|
2024-07-13 17:32:33 +02:00
|
|
|
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
|
|
|
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
2024-11-06 18:53:51 +01:00
|
|
|
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
2024-03-22 10:35:53 +01:00
|
|
|
default: break;
|
|
|
|
}
|
|
|
|
|
2024-01-13 17:03:45 +01:00
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
2023-09-28 18:04:36 +02:00
|
|
|
switch (src0->type) {
|
2024-01-16 14:41:27 +01:00
|
|
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
2024-11-06 18:53:51 +01:00
|
|
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
|
2024-01-16 14:41:27 +01:00
|
|
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
2024-01-30 14:14:12 +01:00
|
|
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
2024-02-24 15:23:52 +01:00
|
|
|
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
2024-02-26 17:28:38 +01:00
|
|
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
2024-02-18 17:16:55 +01:00
|
|
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
2024-03-26 15:21:27 +01:00
|
|
|
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
|
2024-02-21 10:39:52 +01:00
|
|
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
2024-02-27 15:34:24 +01:00
|
|
|
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
2024-07-27 04:41:55 +02:00
|
|
|
default: GGML_ABORT("MUL MAT-MAT not implemented");
|
2024-01-16 14:41:27 +01:00
|
|
|
}
|
2023-09-28 18:04:36 +02:00
|
|
|
|
2024-01-13 17:03:45 +01:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
2024-01-16 14:41:27 +01:00
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
2024-10-25 21:26:15 +02:00
|
|
|
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
|
|
|
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
|
|
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
|
|
|
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
|
2024-01-16 14:41:27 +01:00
|
|
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
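// note: the grid above implies each threadgroup produces a 64-row (src0) x 32-row (src1) block of dst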
|
|
|
|
} else {
|
|
|
|
int nth0 = 32;
|
|
|
|
int nth1 = 1;
|
|
|
|
int nrows = 1;
|
|
|
|
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
2023-11-13 15:55:52 +01:00
|
|
|
|
2024-01-13 17:03:45 +01:00
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
2024-01-16 14:41:27 +01:00
|
|
|
// use custom matrix x vector kernel
|
|
|
|
switch (src0t) {
|
|
|
|
case GGML_TYPE_F32:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
|
|
|
nrows = 4;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_F16:
|
|
|
|
{
|
|
|
|
nth0 = 32;
|
|
|
|
nth1 = 1;
|
|
|
|
if (src1t == GGML_TYPE_F32) {
|
|
|
|
if (ne11 * ne12 < 4) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
|
|
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
|
|
nrows = ne11;
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
|
|
|
nrows = 4;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
|
|
|
nrows = 4;
|
|
|
|
}
|
|
|
|
} break;
|
2024-11-06 18:53:51 +01:00
|
|
|
case GGML_TYPE_BF16:
|
|
|
|
{
|
|
|
|
nth0 = 32;
|
|
|
|
nth1 = 1;
|
|
|
|
if (src1t == GGML_TYPE_F32) {
|
|
|
|
if (ne11 * ne12 < 4) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
|
|
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
|
|
|
nrows = ne11;
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
|
|
|
nrows = 4;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
|
|
|
nrows = 4;
|
|
|
|
}
|
|
|
|
} break;
|
2024-01-16 14:41:27 +01:00
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q4_1:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q5_0:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q5_1:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q8_0:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q2_K:
|
|
|
|
{
|
|
|
|
nth0 = 2;
|
|
|
|
nth1 = 32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q3_K:
|
|
|
|
{
|
|
|
|
nth0 = 2;
|
|
|
|
nth1 = 32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q4_K:
|
|
|
|
{
|
|
|
|
nth0 = 4; //1;
|
|
|
|
nth1 = 8; //32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q5_K:
|
|
|
|
{
|
|
|
|
nth0 = 2;
|
|
|
|
nth1 = 32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q6_K:
|
|
|
|
{
|
|
|
|
nth0 = 2;
|
|
|
|
nth1 = 32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ2_XXS:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ2_XS:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
|
|
|
} break;
|
2024-01-30 14:14:12 +01:00
|
|
|
case GGML_TYPE_IQ3_XXS:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
|
|
|
} break;
|
2024-02-24 15:23:52 +01:00
|
|
|
case GGML_TYPE_IQ3_S:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
|
|
|
} break;
|
2024-02-26 17:28:38 +01:00
|
|
|
case GGML_TYPE_IQ2_S:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
|
|
|
} break;
|
2024-02-18 17:16:55 +01:00
|
|
|
case GGML_TYPE_IQ1_S:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
|
|
|
} break;
|
2024-03-26 15:21:27 +01:00
|
|
|
case GGML_TYPE_IQ1_M:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
|
|
|
} break;
|
2024-02-21 10:39:52 +01:00
|
|
|
case GGML_TYPE_IQ4_NL:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
|
|
|
} break;
|
2024-02-27 15:34:24 +01:00
|
|
|
case GGML_TYPE_IQ4_XS:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
|
|
|
} break;
|
2024-01-16 14:41:27 +01:00
|
|
|
default:
|
|
|
|
{
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
2024-07-27 04:41:55 +02:00
|
|
|
GGML_ABORT("not implemented");
|
2024-01-16 14:41:27 +01:00
|
|
|
}
|
2023-11-13 15:55:52 +01:00
|
|
|
};
|
|
|
|
|
2024-01-13 17:03:45 +01:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
2023-12-13 20:54:54 +01:00
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2024-01-16 14:41:27 +01:00
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
2023-12-13 20:54:54 +01:00
|
|
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2024-10-25 21:26:15 +02:00
|
|
|
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
|
|
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
|
|
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
|
|
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
|
|
|
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
|
|
|
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
|
2024-01-16 14:41:27 +01:00
|
|
|
|
2024-03-26 15:21:27 +01:00
|
|
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
2024-10-25 21:26:15 +02:00
|
|
|
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
|
|
|
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
2024-01-16 14:41:27 +01:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
|
|
|
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
2024-02-24 15:23:52 +01:00
|
|
|
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
|
|
|
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
2024-01-30 14:14:12 +01:00
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
2024-02-27 15:34:24 +01:00
|
|
|
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
2024-02-21 10:39:52 +01:00
|
|
|
const int mem_size = 32*sizeof(float);
|
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
2024-01-16 14:41:27 +01:00
|
|
|
else if (src0t == GGML_TYPE_Q4_K) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_Q3_K) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_Q5_K) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_Q6_K) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
} else {
|
|
|
|
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
}
|
2024-10-01 15:00:25 +02:00
|
|
|
} break;
|
|
|
|
case GGML_OP_MUL_MAT_ID:
|
|
|
|
{
|
|
|
|
const int n_as = src0->ne[2];
|
|
|
|
|
|
|
|
// src2 = ids
|
|
|
|
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
|
|
|
|
|
|
|
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
|
|
|
|
|
|
|
GGML_ASSERT(!ggml_is_transposed(src0));
|
|
|
|
GGML_ASSERT(!ggml_is_transposed(src1));
|
|
|
|
|
|
|
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
|
|
|
2024-10-25 21:26:15 +02:00
|
|
|
GGML_ASSERT(ne03 == 1);
|
|
|
|
GGML_ASSERT(ne13 == 1);
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
|
|
|
// to the matrix-vector kernel
|
|
|
|
// ne20 = n_used_experts
|
|
|
|
// ne21 = n_rows
|
|
|
|
const int dst_rows = ne20*ne21;
|
|
|
|
const int dst_rows_min = n_as;
|
2024-10-07 17:27:51 +02:00
|
|
|
const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
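// note: mirrors the threadgroup allocation further below (8192-byte scratch + 4 bytes per row id, plus padding),
//       so dst_rows_max bounds the rowids array that still fits in threadgroup memory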
|
2024-10-01 15:00:25 +02:00
|
|
|
|
|
|
|
// max size of the rowids array in the kernel shared buffer
|
|
|
|
GGML_ASSERT(dst_rows <= dst_rows_max);
|
|
|
|
|
|
|
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
|
|
// AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
|
|
|
|
// !!!
|
|
|
|
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
|
|
|
// indirect matrix multiplication
|
|
|
|
// !!!
|
2024-10-07 17:27:51 +02:00
|
|
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
2024-10-01 15:00:25 +02:00
|
|
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
|
|
|
dst_rows > dst_rows_min) {
|
|
|
|
// some Metal matrix data types require aligned pointers
|
|
|
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
|
|
|
switch (src0->type) {
|
2024-11-06 18:53:51 +01:00
|
|
|
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
|
|
|
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
|
|
|
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
2024-10-01 15:00:25 +02:00
|
|
|
default: break;
|
|
|
|
}
|
2024-01-13 17:03:45 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
|
|
|
switch (src0->type) {
|
|
|
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
2024-11-06 18:53:51 +01:00
|
|
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
|
2024-10-01 15:00:25 +02:00
|
|
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
|
|
|
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
|
|
|
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
|
|
|
}
|
2023-12-07 21:26:54 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
|
|
|
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
|
|
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
|
|
|
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
|
|
|
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
|
|
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
|
|
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
|
|
|
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
|
|
|
|
|
|
|
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
|
|
} else {
|
|
|
|
int nth0 = 32;
|
|
|
|
int nth1 = 1;
|
|
|
|
int nrows = 1;
|
|
|
|
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
|
|
|
// use custom matrix x vector kernel
|
|
|
|
switch (src0t) {
|
|
|
|
case GGML_TYPE_F32:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_F16:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
|
|
nth0 = 32;
|
|
|
|
nth1 = 1;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
|
|
|
} break;
|
2024-11-06 18:53:51 +01:00
|
|
|
case GGML_TYPE_BF16:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
|
|
nth0 = 32;
|
|
|
|
nth1 = 1;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
|
|
|
|
} break;
|
2024-10-01 15:00:25 +02:00
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q4_1:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q5_0:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q5_1:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q8_0:
|
|
|
|
{
|
|
|
|
nth0 = 8;
|
|
|
|
nth1 = 8;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q2_K:
|
|
|
|
{
|
|
|
|
nth0 = 2;
|
|
|
|
nth1 = 32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q3_K:
|
|
|
|
{
|
|
|
|
nth0 = 2;
|
|
|
|
nth1 = 32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q4_K:
|
|
|
|
{
|
|
|
|
nth0 = 4; //1;
|
|
|
|
nth1 = 8; //32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q5_K:
|
|
|
|
{
|
|
|
|
nth0 = 2;
|
|
|
|
nth1 = 32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q6_K:
|
|
|
|
{
|
|
|
|
nth0 = 2;
|
|
|
|
nth1 = 32;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ2_XXS:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ2_XS:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ3_XXS:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ3_S:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ2_S:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ1_S:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ1_M:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ4_NL:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_IQ4_XS:
|
|
|
|
{
|
|
|
|
nth0 = 4;
|
|
|
|
nth1 = 16;
|
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
|
|
|
} break;
|
|
|
|
default:
|
|
|
|
{
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
2024-10-01 15:00:25 +02:00
|
|
|
GGML_ABORT("not implemented");
|
|
|
|
}
|
|
|
|
};
|
2023-12-07 21:26:54 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
if (ggml_is_quantized(src0t)) {
|
|
|
|
GGML_ASSERT(ne00 >= nth0*nth1);
|
|
|
|
}
|
2024-01-13 17:03:45 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
|
|
|
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
|
|
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
|
|
|
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
|
|
|
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
|
|
|
|
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
|
|
|
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
|
|
|
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
|
|
|
|
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
|
|
|
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
|
|
|
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
|
|
|
|
|
|
|
|
const int64_t _ne1 = 1;
|
|
|
|
const int tgz = dst_rows;
|
|
|
|
|
|
|
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
|
|
|
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
|
|
|
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
|
|
|
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
|
|
|
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
|
|
|
const int mem_size = 32*sizeof(float);
|
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_Q4_K) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_Q3_K) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_Q5_K) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
else if (src0t == GGML_TYPE_Q6_K) {
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
} else {
|
|
|
|
const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
case GGML_OP_GET_ROWS:
{
id<MTLComputePipelineState> pipeline = nil;

switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
default: GGML_ABORT("not implemented");
}

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];

[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
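
// GGML_OP_RMS_NORM: row-wise RMS normalization of src0 with epsilon read from op_params;
// one threadgroup per row, with 32 floats of threadgroup memory for the reduction
// (likely one partial sum per simdgroup)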
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));

float eps;
memcpy(&eps, dst->op_params, sizeof(float));

int nth = 32; // SIMD width
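
// grow the threadgroup up to 1024 threads, but no further than ne00/4, since the kernel
// walks each row in 4-element chunks (hence the ne00 % 4 == 0 assert above)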
while (nth < ne00/4 && nth < 1024) {
nth *= 2;
}

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

const int64_t nrows = ggml_nrows(src0);

[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
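
// GGML_OP_GROUP_NORM: normalize src0 over groups of channels; n_groups comes from
// op_params[0] and eps from op_params[1], and each group is handled by one threadgroup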
case GGML_OP_GROUP_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous(src0));

float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float));

const int32_t n_groups = ((const int32_t *) dst->op_params)[0];

int nth = 32; // SIMD width

//while (nth < ne00/4 && nth < 1024) {
// nth *= 2;
//}

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
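
// GGML_OP_NORM: standard (mean/variance) layer normalization over each row of src0;
// eps from op_params, one threadgroup of up to 256 threads per row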
case GGML_OP_NORM:
{
GGML_ASSERT(ggml_is_contiguous_1(src0));

float eps;
memcpy(&eps, dst->op_params, sizeof(float));

const int nth = MIN(256, ne00);

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];

const int64_t nrows = ggml_nrows(src0);

[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
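
// GGML_OP_ROPE: rotary position embedding; the rope configuration (n_dims, mode,
// frequency base/scale and the YaRN extrapolation factors) is unpacked from op_params below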
case GGML_OP_ROPE:
{
GGML_ASSERT(ne10 == ne02);

const int nth = MIN(1024, ne00);

const int n_past = ((const int32_t *) dst->op_params)[0];
const int n_dims = ((const int32_t *) dst->op_params)[1];
const int mode = ((const int32_t *) dst->op_params)[2];
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];

float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;

memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));

const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;

id<MTLComputePipelineState> pipeline = nil;
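
// two kernel families: ROPE_NORM rotates adjacent pairs of values, while ROPE_NEOX rotates
// values split across the two halves of the head dimension (roughly speaking); each comes
// in an F32 and an F16 variant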
if (!is_neox) {
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
} else {
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
}

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
// src2 is optional (presumably the frequency-factor tensor); bind src0 as a placeholder when it is absent
if (id_src2 != nil) {
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
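
// GGML_OP_IM2COL: unfold the input (src1, F32) into columns using the kernel shape of
// src0 (F16) so that convolution can be done as a matrix multiplication; 1D or 2D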
case GGML_OP_IM2COL:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);

const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];

const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;

const int32_t N = src1->ne[is_2D ? 3 : 2];
const int32_t IC = src1->ne[is_2D ? 2 : 1];
const int32_t IH = is_2D ? src1->ne[1] : 1;
const int32_t IW = src1->ne[0];

const int32_t KH = is_2D ? src0->ne[1] : 1;
const int32_t KW = src0->ne[0];

const int32_t OH = is_2D ? dst->ne[2] : 1;
const int32_t OW = dst->ne[1];

const int32_t CHW = IC * KH * KW;

const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
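
// probe maxTotalThreadsPerThreadgroup from the F32 pipeline: when one threadgroup cannot
// cover N*KH*KW threads, switch to the "EXT" kernels, which split that work across threadgroups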
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;

switch (dst->type) {
case GGML_TYPE_F32: {
pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
} break;
case GGML_TYPE_F16: {
pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
} break;
default: GGML_ABORT("fatal error");
};

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];

if (is_gt_mttpt) {
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];

const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);

const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);

[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} else {
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
}
} break;
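
// GGML_OP_UPSCALE: resize src0 (F32) to the dst shape; the per-dimension scale factors
// sf0..sf3 are simply the ratio of the dst and src0 extents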
case GGML_OP_UPSCALE:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);

const float sf0 = (float)ne0/src0->ne[0];
const float sf1 = (float)ne1/src0->ne[1];
const float sf2 = (float)ne2/src0->ne[2];
const float sf3 = (float)ne3/src0->ne[3];

const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
[encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
[encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
[encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
[encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];

const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
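
// GGML_OP_PAD: copy src0 (F32) into the larger dst tensor, filling the padded region with
// zeros; the amount of padding is implied by the difference between the two shapes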
case GGML_OP_PAD:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];

const int nth = MIN(1024, ne0);

[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
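
// GGML_OP_ARANGE: fill dst (F32, no inputs) with an arithmetic sequence; start and step
// come from op_params[0] and op_params[2] (op_params[1], presumably the stop value, is
// implied by the length of dst and not needed by the kernel)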
case GGML_OP_ARANGE:
{
GGML_ASSERT(dst->type == GGML_TYPE_F32);

float start;
float step;

memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float));

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
|
|
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
|
|
|
|
[encoder setBytes:&start length:sizeof(start) atIndex:2];
|
|
|
|
[encoder setBytes:&step length:sizeof(step) atIndex:3];
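                // note: assuming the arange pipeline bound above follows the usual
                // ggml_arange contract, it fills dst with ne0 evenly spaced values,
                // dst[i] = start + i*step; the single-threadgroup dispatch below then
                // presumably lets MIN(1024, ne0) threads stride over that range.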
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int nth = MIN(1024, ne0);
|
2024-05-11 09:32:41 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2024-05-11 09:32:41 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int dim = dst->op_params[0];
|
|
|
|
const int max_period = dst->op_params[1];
|
2024-05-11 09:32:41 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int half = dim / 2;
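                // note: assuming the kernel mirrors the reference sinusoidal timestep
                // embedding used by diffusion models, each timestep t produces cos(t*f_j)
                // and sin(t*f_j) for the `half` frequencies f_j = exp(-log(max_period) * j / half),
                // with the last column zero-padded when dim is odd.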
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
|
|
|
|
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
|
|
|
|
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int nth = MIN(1024, half);
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
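                // note: ne00 is the number of timesteps in src0, so this presumably
                // launches one threadgroup per timestep, each with MIN(1024, half)
                // threads covering the frequency indices.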
|
|
|
|
} break;
|
|
|
|
case GGML_OP_ARGSORT:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int nrows = ggml_nrows(src0);
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
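                // note: op_params[0] carries the requested sort direction; the pipeline
                // chosen further below is presumably the ascending or descending variant
                // of the argsort kernel accordingly.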
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
                // bitonic sort requires the number of elements to be a power of 2
|
|
|
|
int64_t ne00_padded = 1;
|
|
|
|
while (ne00_padded < ne00) {
|
|
|
|
ne00_padded *= 2;
|
|
|
|
}
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
                // Metal kernels require the buffer size to be a multiple of 16 bytes
|
|
|
|
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
|
|
|
const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
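                // note (worked example): for ne00 = 100 the loop above yields
                // ne00_padded = 128, so mem_size = GGML_PAD(128*sizeof(int32_t), 16) = 512
                // bytes, which is already a multiple of 16.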
|
2024-04-30 11:16:08 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
switch (order) {
|
|
|
|
case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
|
|
|
|
case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
|
|
|
|
default: GGML_ABORT("fatal error");
|
|
|
|
};
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
|
|
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
|
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
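// note (added comment): this launches one threadgroup per row, with ne00_padded threads that
// cooperatively argsort that row using the mem_size bytes of threadgroup memory reserved above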
|
|
|
|
} break;
|
|
|
|
case GGML_OP_LEAKY_RELU:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
float slope;
|
|
|
|
memcpy(&slope, dst->op_params, sizeof(float));
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
const int64_t n = ggml_nelements(dst);
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
} break;
|
|
|
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
|
|
{
|
|
|
|
GGML_ASSERT(ne00 % 4 == 0);
|
|
|
|
GGML_ASSERT(ne11 % 32 == 0);
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2024-11-06 09:24:23 +01:00
|
|
|
GGML_ASSERT(src1->type == src2->type);
|
2024-10-01 15:00:25 +02:00
|
|
|
|
|
|
|
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
|
|
|
|
|
|
|
struct ggml_tensor * src3 = node->src[3];
|
|
|
|
|
|
|
|
size_t offs_src3 = 0;
|
|
|
|
|
|
|
|
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
|
|
|
|
|
|
|
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
|
|
|
|
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
|
|
|
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
|
|
|
|
|
|
|
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
|
|
|
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
|
|
|
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
|
|
|
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
|
|
|
|
|
|
|
|
const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
|
|
|
|
const uint64_t nb31 = src3 ? src3->nb[1] : 0;
|
|
|
|
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
|
|
|
|
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
|
|
|
|
|
|
|
|
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
|
|
|
|
|
|
|
float scale;
|
|
|
|
float max_bias;
|
|
|
|
float logit_softcap;
|
|
|
|
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
|
|
|
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
|
|
|
memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
|
|
|
|
|
|
|
|
if (logit_softcap != 0.0f) {
|
|
|
|
scale /= logit_softcap;
|
|
|
|
}
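// note (added comment, inferred): pre-dividing the scale here presumably lets the kernel apply the
// soft cap as logit_softcap*tanh(scale*qk) using the already-adjusted scale, avoiding an extra
// division in the inner loop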
|
|
|
|
|
|
|
|
const uint32_t n_head = src0->ne[2];
|
|
|
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
|
|
|
|
|
|
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
|
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
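// illustrative example (values assumed, not from the source): with n_head = 20, n_head_log2 = 16
// and max_bias = 8.0f this gives m0 = 2^(-8/16) ~= 0.7071 and m1 = 2^(-4/16) ~= 0.8409, the two
// geometric bases used to derive the per-head ALiBi slopes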
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
|
|
|
bool use_vec_kernel = false;
|
|
|
|
|
2024-11-09 10:53:02 +01:00
|
|
|
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
|
|
|
// avoided for now, mainly to keep the number of templates/kernels a bit lower
|
2024-10-01 15:00:25 +02:00
|
|
|
if (ne01 >= 4 || (ne00%128 != 0)) {
|
2024-11-06 09:24:23 +01:00
|
|
|
switch (src1->type) {
|
|
|
|
case GGML_TYPE_F16:
|
|
|
|
{
|
|
|
|
switch (ne00) {
|
|
|
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
|
|
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
|
|
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
|
|
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
|
|
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
|
|
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
|
|
|
GGML_ABORT("add template specialization for this size");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
2024-11-08 12:47:22 +01:00
|
|
|
case GGML_TYPE_BF16:
|
|
|
|
{
|
|
|
|
switch (ne00) {
|
|
|
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
|
|
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
|
|
|
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
|
|
|
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
|
|
|
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
|
|
|
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
|
|
|
GGML_ABORT("add template specialization for this size");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
2024-11-06 09:24:23 +01:00
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
|
{
|
|
|
|
switch (ne00) {
|
|
|
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
|
|
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
|
|
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
|
|
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
|
|
|
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
|
|
|
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
|
|
|
GGML_ABORT("add template specialization for this size");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q4_1:
|
|
|
|
{
|
|
|
|
switch (ne00) {
|
|
|
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
|
|
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
|
|
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
|
|
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
|
|
|
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
|
|
|
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
|
|
|
GGML_ABORT("add template specialization for this size");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q5_0:
|
|
|
|
{
|
|
|
|
switch (ne00) {
|
|
|
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
|
|
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
|
|
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
|
|
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
|
|
|
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
|
|
|
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
|
|
|
GGML_ABORT("add template specialization for this size");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q5_1:
|
|
|
|
{
|
|
|
|
switch (ne00) {
|
|
|
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
|
|
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
|
|
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
|
|
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
|
|
|
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
|
|
|
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
|
|
|
GGML_ABORT("add template specialization for this size");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q8_0:
|
|
|
|
{
|
|
|
|
switch (ne00) {
|
|
|
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
|
|
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
|
|
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
|
|
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
|
|
|
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
|
|
|
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
|
|
|
GGML_ABORT("add template specialization for this size");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
2024-10-01 15:00:25 +02:00
|
|
|
default:
|
2024-11-06 09:24:23 +01:00
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this type\n");
|
|
|
|
GGML_ABORT("add template specialization for this type");
|
|
|
|
}
|
2024-10-01 15:00:25 +02:00
|
|
|
}
|
|
|
|
} else {
|
|
|
|
use_vec_kernel = true;
|
|
|
|
|
|
|
|
switch (ne00) {
|
2024-11-06 09:24:23 +01:00
|
|
|
case 128:
|
|
|
|
{
|
|
|
|
switch (src1->type) {
|
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
2024-11-08 12:47:22 +01:00
|
|
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
|
2024-11-06 09:24:23 +01:00
|
|
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this type\n");
|
|
|
|
GGML_ABORT("add template specialization for this type");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
case 256:
|
|
|
|
{
|
|
|
|
switch (src1->type) {
|
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
2024-11-08 12:47:22 +01:00
|
|
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
|
2024-11-06 09:24:23 +01:00
|
|
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
|
|
|
|
default:
|
|
|
|
{
|
|
|
|
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this type\n");
|
|
|
|
GGML_ABORT("add template specialization for this type");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
2024-10-01 15:00:25 +02:00
|
|
|
default:
|
|
|
|
{
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
2024-10-01 15:00:25 +02:00
|
|
|
GGML_ABORT("add template specialization for this size");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
|
|
|
if (id_src3) {
|
|
|
|
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
|
|
|
} else {
|
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
|
|
|
}
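// note (added comment): when there is no mask (src3 == nil), src0 is re-bound at index 3 as a
// placeholder so that buffer slot 3 is always populated for the kernel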
|
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
|
|
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
|
|
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
|
|
|
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
|
|
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
|
|
|
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
|
|
|
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
|
|
|
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
|
|
|
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
|
|
|
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
2024-11-08 12:47:22 +01:00
|
|
|
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
|
|
|
|
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
|
|
|
|
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
|
|
|
|
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
|
|
|
|
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
|
|
|
|
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
|
|
|
|
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
|
2024-10-01 15:00:25 +02:00
|
|
|
|
|
|
|
if (!use_vec_kernel) {
|
|
|
|
// half8x8 kernel
|
|
|
|
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
|
|
|
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
|
|
|
|
|
|
|
GGML_ASSERT(nqptg <= 32);
|
|
|
|
GGML_ASSERT(nqptg % 8 == 0);
|
|
|
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
|
|
|
2024-11-08 12:47:22 +01:00
|
|
|
// 2*(2*ncpsg + nqptg)*(nsg)
|
|
|
|
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
|
|
|
|
//
|
2024-11-06 09:24:23 +01:00
|
|
|
// 16*32*(nsg)
|
|
|
|
// the shared memory needed for the simdgroups to load the KV cache
|
|
|
|
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
|
|
|
|
//
|
2024-11-08 12:47:22 +01:00
|
|
|
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
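// illustrative example (values assumed, not from the source): for ne00 = 128, nqptg = 8, ncpsg = 32
// and nsg = 4 this is GGML_PAD((8*(128 + 2*(2*32 + 8)*4) + 16*32*4)*2, 16)
//                   = GGML_PAD((5632 + 2048)*2, 16) = 15360 bytes of threadgroup memory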
|
2024-11-06 09:24:23 +01:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
                    int64_t nsgmax = 2;

                    while (true) {
                        const size_t smem = FATTN_SMEM(nsgmax);
                        if (smem > device.maxThreadgroupMemoryLength) {
                            break;
                        }
                        nsgmax *= 2;
                    }
                    nsgmax /= 2;

                    // simdgroups per threadgroup (a.k.a. warps)
                    const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;

                    const size_t smem = FATTN_SMEM(nsg);

                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
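                    // descriptive note (added): the grid is one threadgroup per nqptg query rows,
                    // i.e. ceil(ne01/nqptg) x ne02 x ne03 threadgroups, each running nsg simdgroups
                    // of 32 threads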
                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                } else {
                    // half4x4 kernel
                    const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!

                    GGML_ASSERT(nqptg <= 32);
                    GGML_ASSERT(nqptg % 1 == 0);
                    GGML_ASSERT(ncpsg % 32 == 0);

                    // ne00 + 2*ncpsg*(nsg)
                    // for each query, we load it as f16 in shared memory (ne00)
                    // and store the soft_max values and the mask
                    //
                    // ne00*(nsg)
                    // each simdgroup has a full f16 head vector in shared mem to accumulate results
                    //
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
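
                    // illustrative example (added, not from the original source): for ne00 = 128, nqptg = 1,
                    // ncpsg = 32 and nsg = 4, FATTN_SMEM(4) = (1*(128 + 2*32*4) + 128*4)*2 = (384 + 512)*2 = 1792 bytes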
                    int64_t nsgmax = 2;

                    while (true) {
                        const size_t smem = FATTN_SMEM(nsgmax);
                        if (smem > device.maxThreadgroupMemoryLength) {
                            break;
                        }
                        nsgmax *= 2;
                    }
                    nsgmax /= 2;

                    // simdgroups per threadgroup (a.k.a. warps)
                    const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));

                    int64_t nsg = 1;
                    while (nsg <= nsgt) {
                        nsg *= 2;
                    }
                    nsg /= 2;

                    const size_t smem = FATTN_SMEM(nsg);

                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                }
            } break;
        case GGML_OP_DUP:
        case GGML_OP_CPY:
        case GGML_OP_CONT:
            {
                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);

                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
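
                // illustrative note (added): nth is the threadgroup width in units of src0 blocks;
                // e.g. for an F32 source (block size 1) with ne00 = 4096, nth = MIN(1024, 4096) = 1024
                // threads, and each threadgroup processes one src0 row (see the dispatch below)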

                id<MTLComputePipelineState> pipeline = nil;

                switch (src0t) {
                    case GGML_TYPE_F32:
                        {
                            GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);

                            switch (dstt) {
                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
                                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
                                case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
                                case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
                                case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
                                case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
                                case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
                                case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
                                default: GGML_ABORT("not implemented");
                            };
                        } break;
                    case GGML_TYPE_F16:
                        {
                            switch (dstt) {
                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
                                default: GGML_ABORT("not implemented");
                            };
                        } break;
                    case GGML_TYPE_BF16:
                        {
                            switch (dstt) {
                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
                                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
                                default: GGML_ASSERT(false && "not implemented");
                            };
                        } break;
                    default: GGML_ABORT("not implemented");
                }

                [encoder setComputePipelineState:pipeline];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
                [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
                [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
                [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
                [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
                [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
                [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
                [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
                [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
                [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
                [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
                [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
                [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];

                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
            } break;
        case GGML_OP_POOL_2D:
            {
                GGML_ASSERT(ggml_is_contiguous(src0));
                GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);

                const int32_t * opts = dst->op_params;
                enum ggml_op_pool op = opts[0];

                id<MTLComputePipelineState> pipeline = nil;
                switch (src0t) {
                    case GGML_TYPE_F32: {
                        switch(op) {
                            case GGML_OP_POOL_AVG:
                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
                            case GGML_OP_POOL_MAX:
                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
                            default: GGML_ASSERT(false && "not implemented");
                        }
                    } break;
                    default: GGML_ASSERT(false && "not implemented");
                }

                const int32_t k0 = opts[1];
                const int32_t k1 = opts[2];
                const int32_t s0 = opts[3];
                const int32_t s1 = opts[4];
                const int32_t p0 = opts[5];
                const int32_t p1 = opts[6];

                const int64_t IH = src0->ne[1];
                const int64_t IW = src0->ne[0];

                const int64_t N = dst->ne[3];
                const int64_t OC = dst->ne[2];
                const int64_t OH = dst->ne[1];
                const int64_t OW = dst->ne[0];

                const int64_t parallel_elements = N * OC * OH * OW;
                const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
                const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
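
                // illustrative example (added): for N = 1, OC = 64, OH = OW = 32 the pool kernel has
                // parallel_elements = 65536; with maxTotalThreadsPerThreadgroup = 1024 this gives
                // n_threads = 1024 and n_tg = 64 threadgroups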

                [encoder setComputePipelineState:pipeline];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
                [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
                [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
                [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
                [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
                [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
                [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
                [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
                [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
                [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
                [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];

                [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
            } break;
        default:
            {
                GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
                GGML_ABORT("fatal error");
            }
    }
}

static enum ggml_status ggml_metal_graph_compute(
        ggml_backend_t backend,
        struct ggml_cgraph * gf) {
    struct ggml_backend_metal_context * ctx = backend->context;
    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;

    // number of nodes encoded by the main thread (empirically determined)
    const int n_main = 128;

    // number of threads in addition to the main thread
    const int n_cb = ctx->n_cb;

    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
    // each thread creates its own command buffer and enqueues the ops in parallel
    //
    // tests on M1 Pro and M2 Ultra using LLaMA models show that the optimal values for n_cb are 1 or 2
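    //
    // illustrative example (added): with gf->n_nodes = 1024, n_main = 128 and n_cb = 2, the split below
    // gives n_nodes_0 = 128 nodes for the main thread and n_nodes_1 = 896 nodes for the workers,
    // i.e. n_nodes_per_cb = 448 nodes per extra command buffer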

    @autoreleasepool {
        ctx->gf = gf;

        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;

        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;

        const bool should_capture = ctx->capture_next_compute;
        if (should_capture) {
            ctx->capture_next_compute = false;

            if (!ctx->capture_started) {
                // create capture scope
                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
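
                // note (added): the capture is written as a GPU trace document to the outputURL below,
                // which can then be inspected with Xcode's Metal debugger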
                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                descriptor.captureObject = ctx->capture_scope;
                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];

                NSError * error = nil;
                if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
                    GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
                } else {
                    [ctx->capture_scope beginScope];
                    ctx->capture_started = true;
                }
            }
        }

        // the main thread commits the first few commands immediately
        // command_buffer[n_cb]
        {
            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
            ctx->command_buffers[n_cb] = command_buffer;

            [command_buffer enqueue];
            ctx->encode_async(n_cb);
        }

        // prepare the rest of the command buffers asynchronously
        // command_buffer[0.. n_cb)
        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
            ctx->command_buffers[cb_idx] = command_buffer;

            // always enqueue the first two command buffers
            // enqueue all of the command buffers if we don't need to abort
            if (cb_idx < 2 || ctx->abort_callback == NULL) {
                [command_buffer enqueue];
            }
        }

        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);

        // wait for completion and check status of each command buffer
        // needed to detect if the device ran out-of-memory for example (#1881)
        {
            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
            [command_buffer waitUntilCompleted];

            MTLCommandBufferStatus status = [command_buffer status];
            if (status != MTLCommandBufferStatusCompleted) {
                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
                if (status == MTLCommandBufferStatusError) {
                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
                }

                return GGML_STATUS_FAILED;
            }
        }

        for (int i = 0; i < n_cb; ++i) {
            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
            [command_buffer waitUntilCompleted];

            MTLCommandBufferStatus status = [command_buffer status];
            if (status != MTLCommandBufferStatusCompleted) {
                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
                if (status == MTLCommandBufferStatusError) {
                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
                }

                return GGML_STATUS_FAILED;
            }
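
            // descriptive note (added): command buffers are committed in a pipelined fashion - once
            // buffer i has completed successfully, buffer i+1 is committed, unless it was already
            // enqueued above or the abort callback requested a stop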
            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
            if (!next_buffer) {
                continue;
            }

            const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
            if (next_queued) {
                continue;
            }

            if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
                GGML_LOG_INFO("%s: command buffer %d aborted\n", __func__, i);
                return GGML_STATUS_ABORTED;
            }

            [next_buffer commit];
        }

        if (!should_capture && ctx->capture_started) {
            [ctx->capture_scope endScope];
            [[MTLCaptureManager sharedCaptureManager] stopCapture];
        }
    }

    return GGML_STATUS_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////

// backend interface

static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;

    for (int i = 0; i < ctx->n_buffers; i++) {
        [ctx->buffers[i].metal release];
    }

    ggml_backend_metal_device_rel(buffer->buft->device->context);

    if (ctx->owned) {
#if TARGET_OS_OSX
        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
#else
        free(ctx->all_data);
#endif
    }

    free(ctx);
}

static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;

    return ctx->all_data;
}

static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    memcpy((char *)tensor->data + offset, data, size);

    UNUSED(buffer);
}

static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    memcpy(data, (const char *)tensor->data + offset, size);

    UNUSED(buffer);
}

static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
    if (ggml_backend_buffer_is_host(src->buffer)) {
        memcpy(dst->data, src->data, ggml_nbytes(src));
        return true;
    }
    return false;

    UNUSED(buffer);
}

static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;

    memset(ctx->all_data, value, ctx->all_size);
}

static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
    /* .free_buffer   = */ ggml_backend_metal_buffer_free_buffer,
    /* .get_base      = */ ggml_backend_metal_buffer_get_base,
    /* .init_tensor   = */ NULL,
    /* .memset_tensor = */ NULL,
    /* .set_tensor    = */ ggml_backend_metal_buffer_set_tensor,
    /* .get_tensor    = */ ggml_backend_metal_buffer_get_tensor,
    /* .cpy_tensor    = */ ggml_backend_metal_buffer_cpy_tensor,
    /* .clear         = */ ggml_backend_metal_buffer_clear,
    /* .reset         = */ NULL,
};

// default buffer type

static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "Metal";

    UNUSED(buft);
}

static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
#ifndef GGML_METAL_NDEBUG
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
    if (@available(macOS 10.12, iOS 16.0, *)) {
        GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n",
            __func__,
            size_aligned / 1024.0 / 1024.0,
            device.currentAllocatedSize / 1024.0 / 1024.0,
            device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

        if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
            GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
        }
    } else {
        GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
            __func__,
            size_aligned / 1024.0 / 1024.0,
            device.currentAllocatedSize / 1024.0 / 1024.0);
    }
#endif
#endif
    UNUSED(device);
    UNUSED(size_aligned);
}

static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));

    const size_t size_page = sysconf(_SC_PAGESIZE);

    size_t size_aligned = size;
    if ((size_aligned % size_page) != 0) {
        size_aligned += (size_page - (size_aligned % size_page));
    }
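
    // illustrative example (added): with a 16 KiB page size, a request of size = 5000 bytes is
    // rounded up to size_aligned = 16384 so the Metal buffer is page-aligned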

    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);

    ctx->all_data = ggml_metal_host_malloc(size_aligned);
    ctx->all_size = size_aligned;
    ctx->owned = true;
    ctx->n_buffers = 1;

    if (ctx->all_data != NULL) {
        ctx->buffers[0].data = ctx->all_data;
        ctx->buffers[0].size = size;
        ctx->buffers[0].metal = nil;

        if (size_aligned > 0) {
            ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
                                                              length:size_aligned
                                                             options:MTLResourceStorageModeShared
                                                         deallocator:nil];
        }
    }

    if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
        GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
        free(ctx);
        ggml_backend_metal_device_rel(buft->device->context);
        return NULL;
    }
//ggml_backend_metal_log_allocated_size(device, size_aligned);
|
2023-12-21 21:07:46 +01:00
|
|
|
|
|
|
|
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
|
2023-10-08 19:19:14 +02:00
|
|
|
}
|
|
|
|
|
2024-10-03 01:49:47 +02:00
|
|
|
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
2023-10-08 19:19:14 +02:00
|
|
|
return 32;
|
2023-12-07 21:26:54 +01:00
|
|
|
UNUSED(buft);
|
2023-10-08 19:19:14 +02:00
|
|
|
}
|
|
|
|
|
2024-10-03 01:49:47 +02:00
|
|
|
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
2024-10-07 17:27:51 +02:00
|
|
|
id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
|
|
|
|
const size_t max_size = device.maxBufferLength;
|
|
|
|
ggml_backend_metal_device_rel(buft->device->context);
|
2024-01-29 09:05:13 +01:00
|
|
|
|
|
|
|
return max_size;
|
|
|
|
|
|
|
|
UNUSED(buft);
|
|
|
|
}
|
|
|
|
|
2024-10-03 01:49:47 +02:00
|
|
|
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
2023-12-21 21:07:46 +01:00
|
|
|
return true;
|
|
|
|
|
|
|
|
UNUSED(buft);
|
2023-10-08 19:19:14 +02:00
|
|
|
}
|
|
|
|
|
2024-10-03 01:49:47 +02:00
|
|
|
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
2023-12-07 21:26:54 +01:00
|
|
|
static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
|
|
|
/* .iface = */ {
|
2024-01-12 20:07:38 +01:00
|
|
|
/* .get_name = */ ggml_backend_metal_buffer_type_get_name,
|
2023-12-07 21:26:54 +01:00
|
|
|
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
|
|
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
2024-01-29 09:05:13 +01:00
|
|
|
/* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
|
2023-12-07 21:26:54 +01:00
|
|
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
2023-12-21 21:07:46 +01:00
|
|
|
/* .is_host = */ ggml_backend_metal_buffer_type_is_host,
|
2023-12-07 21:26:54 +01:00
|
|
|
},
|
2024-10-07 17:27:51 +02:00
|
|
|
/* .device = */ &g_ggml_backend_metal_device,
|
2023-12-07 21:26:54 +01:00
|
|
|
/* .context = */ NULL,
|
|
|
|
};
|
2023-10-08 19:19:14 +02:00
|
|
|
|
2023-12-07 21:26:54 +01:00
|
|
|
return &ggml_backend_buffer_type_metal;
|
2023-10-08 19:19:14 +02:00
|
|
|
}
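// Minimal usage sketch, assuming the generic helpers ggml_backend_buft_alloc_buffer() and
// ggml_backend_buffer_free() from ggml-backend.h: allocate a device buffer through the
// buffer type returned above and release it when done.
static void example_alloc_metal_buffer(void) {
    ggml_backend_buffer_type_t buft = ggml_backend_metal_buffer_type();

    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 64u*1024*1024); // 64 MiB
    if (buf != NULL) {
        // ... place tensors into the buffer as usual ...
        ggml_backend_buffer_free(buf);
    }
}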
|
|
|
|
|
2024-10-30 02:01:23 +01:00
|
|
|
static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
|
|
|
|
return "Metal_Mapped";
|
|
|
|
|
|
|
|
UNUSED(buft);
|
|
|
|
}
|
|
|
|
|
|
|
|
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
|
|
|
|
static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = {
|
|
|
|
/* .iface = */ {
|
|
|
|
/* .get_name = */ ggml_backend_metal_buffer_from_ptr_type_get_name,
|
|
|
|
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
|
|
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
|
|
|
/* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
|
|
|
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
|
|
|
/* .is_host = */ ggml_backend_metal_buffer_type_is_host,
|
|
|
|
},
|
|
|
|
/* .device = */ &g_ggml_backend_metal_device,
|
|
|
|
/* .context = */ NULL,
|
|
|
|
};
|
|
|
|
|
|
|
|
return &ggml_backend_buffer_from_ptr_type_metal;
|
|
|
|
}
|
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
|
2024-10-03 01:49:47 +02:00
|
|
|
ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
|
2024-10-05 13:33:54 +02:00
|
|
|
struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
|
2023-12-21 21:07:46 +01:00
|
|
|
|
|
|
|
ctx->all_data = data;
|
|
|
|
ctx->all_size = size;
|
|
|
|
ctx->owned = false;
|
|
|
|
ctx->n_buffers = 0;
|
|
|
|
|
|
|
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
2024-01-12 20:07:38 +01:00
|
|
|
|
|
|
|
// page-align the data ptr
|
|
|
|
{
|
|
|
|
const uintptr_t offs = (uintptr_t) data % size_page;
|
|
|
|
data = (void *) ((char *) data - offs);
|
|
|
|
size += offs;
|
|
|
|
}
|
|
|
|
|
2023-12-21 21:07:46 +01:00
|
|
|
size_t size_aligned = size;
|
|
|
|
if ((size_aligned % size_page) != 0) {
|
|
|
|
size_aligned += (size_page - (size_aligned % size_page));
|
|
|
|
}
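// worked example with hypothetical numbers: size_page = 16384 and data = base + 100 (base
// page-aligned) gives offs = 100, so data is moved back to the page boundary and size grows
// by 100; size_aligned then rounds the total up to the next multiple of 16384, which is the
// layout newBufferWithBytesNoCopy expects for a shared no-copy buffer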
|
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
|
2023-12-21 21:07:46 +01:00
|
|
|
|
|
|
|
// the buffer fits into the max buffer size allowed by the device
|
|
|
|
if (size_aligned <= device.maxBufferLength) {
|
2024-09-16 08:05:56 +02:00
|
|
|
ctx->buffers[ctx->n_buffers].data = data;
|
|
|
|
ctx->buffers[ctx->n_buffers].size = size;
|
|
|
|
ctx->buffers[ctx->n_buffers].metal = nil;
|
2023-12-21 21:07:46 +01:00
|
|
|
|
2024-09-16 08:05:56 +02:00
|
|
|
if (size_aligned > 0) {
|
|
|
|
ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
2023-12-21 21:07:46 +01:00
|
|
|
|
2024-09-16 08:05:56 +02:00
|
|
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
2024-09-16 08:05:56 +02:00
|
|
|
return NULL;
|
|
|
|
}
|
2023-12-21 21:07:46 +01:00
|
|
|
}
|
|
|
|
|
2024-04-30 11:16:08 +02:00
|
|
|
ggml_backend_metal_log_allocated_size(device, size_aligned);
|
2023-12-21 21:07:46 +01:00
|
|
|
|
|
|
|
++ctx->n_buffers;
|
|
|
|
} else {
|
|
|
|
// this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
|
|
|
|
// one of the views
|
|
|
|
const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
|
|
|
|
const size_t size_step = device.maxBufferLength - size_ovlp;
|
|
|
|
const size_t size_view = device.maxBufferLength;
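// worked example with hypothetical numbers: size_page = 16 KiB, maxBufferLength = 8 GiB and
// max_size = 1 GiB give size_ovlp ~= 1 GiB + 16 KiB and size_step ~= 7 GiB, so the views
// cover [0, 8 GiB), [7 GiB, 15 GiB), ... and any tensor of at most max_size bytes that
// straddles the end of one view is still fully contained in the next one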
|
|
|
|
|
|
|
|
for (size_t i = 0; i < size; i += size_step) {
|
|
|
|
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
|
|
|
|
|
2024-09-16 08:05:56 +02:00
|
|
|
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
|
|
|
|
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
|
|
|
|
ctx->buffers[ctx->n_buffers].metal = nil;
|
2023-12-21 21:07:46 +01:00
|
|
|
|
2024-09-16 08:05:56 +02:00
|
|
|
if (size_step_aligned > 0) {
|
|
|
|
ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
2023-12-21 21:07:46 +01:00
|
|
|
|
2024-09-16 08:05:56 +02:00
|
|
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
|
2024-09-16 08:05:56 +02:00
|
|
|
return NULL;
|
|
|
|
}
|
2023-12-21 21:07:46 +01:00
|
|
|
}
|
|
|
|
|
2024-04-30 11:16:08 +02:00
|
|
|
ggml_backend_metal_log_allocated_size(device, size_step_aligned);
|
|
|
|
|
2023-12-21 21:07:46 +01:00
|
|
|
if (i + size_step < size) {
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_INFO("\n");
|
2023-12-21 21:07:46 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
++ctx->n_buffers;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-10-30 02:01:23 +01:00
|
|
|
return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
|
2023-12-21 21:07:46 +01:00
|
|
|
}
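// Usage sketch (hypothetical helper, assumes <fcntl.h>, <sys/mman.h>, <sys/stat.h>): wrap an
// mmap-ed weights file in a zero-copy Metal buffer; max_tensor_size is the size of the largest
// tensor in the mapping and drives the view overlap computed above.
static ggml_backend_buffer_t example_map_weights(const char * path, size_t max_tensor_size) {
    const int fd = open(path, O_RDONLY);
    if (fd < 0) {
        return NULL;
    }

    struct stat st;
    if (fstat(fd, &st) != 0) {
        close(fd);
        return NULL;
    }

    void * data = mmap(NULL, st.st_size, PROT_READ, MAP_SHARED, fd, 0);
    close(fd); // the mapping remains valid after the descriptor is closed

    if (data == MAP_FAILED) {
        return NULL;
    }

    return ggml_backend_metal_buffer_from_ptr(data, st.st_size, max_tensor_size);
}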
|
|
|
|
|
|
|
|
// backend
|
|
|
|
|
2024-10-03 01:49:47 +02:00
|
|
|
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
2023-12-07 21:26:54 +01:00
|
|
|
return "Metal";
|
|
|
|
|
2023-10-08 19:19:14 +02:00
|
|
|
UNUSED(backend);
|
|
|
|
}
|
|
|
|
|
2024-10-03 01:49:47 +02:00
|
|
|
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
2024-10-07 17:27:51 +02:00
|
|
|
struct ggml_backend_metal_context * ctx = backend->context;
|
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
|
|
|
|
|
|
ggml_backend_metal_device_rel(ctx_dev);
|
2023-12-07 21:26:54 +01:00
|
|
|
ggml_metal_free(ctx);
|
2024-10-07 17:27:51 +02:00
|
|
|
|
2023-12-07 21:26:54 +01:00
|
|
|
free(backend);
|
|
|
|
}
|
2023-10-08 19:19:14 +02:00
|
|
|
|
2024-10-03 01:49:47 +02:00
|
|
|
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
2024-10-07 17:27:51 +02:00
|
|
|
return ggml_metal_graph_compute(backend, cgraph);
|
2024-06-13 03:11:35 +02:00
|
|
|
}
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
|
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
|
|
|
|
|
|
|
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
|
|
|
|
|
|
|
if (ctx->n_cb != n_cb) {
|
|
|
|
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
|
|
|
|
|
|
|
|
if (ctx->n_cb > 2) {
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
|
2024-10-01 15:00:25 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-10-07 14:26:31 +02:00
|
|
|
if (ctx->encode_async) {
|
|
|
|
Block_release(ctx->encode_async);
|
|
|
|
}
|
|
|
|
|
|
|
|
ctx->encode_async = Block_copy(^(size_t iter) {
|
|
|
|
const int cb_idx = iter;
|
|
|
|
const int n_cb_l = ctx->n_cb;
|
|
|
|
|
|
|
|
const int n_nodes_0 = ctx->n_nodes_0;
|
|
|
|
const int n_nodes_1 = ctx->n_nodes_1;
|
|
|
|
|
|
|
|
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
|
|
|
|
|
|
|
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
|
|
|
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
|
|
|
|
|
|
|
|
int node_start = 0;
|
|
|
|
int node_end = n_nodes_0;
|
|
|
|
|
|
|
|
if (cb_idx < n_cb_l) {
|
|
|
|
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
|
|
|
|
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
|
|
|
|
}
|
|
|
|
|
|
|
|
const bool should_capture = ctx->capture_next_compute;
|
|
|
|
|
|
|
|
for (int idx = node_start; idx < node_end; ++idx) {
|
|
|
|
if (should_capture) {
|
|
|
|
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
|
|
|
}
|
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
ggml_metal_encode_node(backend, idx, encoder);
|
2024-10-07 14:26:31 +02:00
|
|
|
|
|
|
|
if (should_capture) {
|
|
|
|
[encoder popDebugGroup];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
[encoder endEncoding];
|
|
|
|
|
|
|
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
|
|
|
[command_buffer commit];
|
|
|
|
}
|
|
|
|
});
|
2024-10-01 15:00:25 +02:00
|
|
|
}
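// Standalone restatement of the node-range arithmetic used by the encode_async block above
// (illustration only): indices 0..n_cb-1 split the last n_nodes_1 graph nodes into chunks of
// n_nodes_per_cb (the final chunk is clamped to n_nodes_1), while any index >= n_cb falls
// through to the first n_nodes_0 nodes.
static void example_node_range(int cb_idx, int n_cb, int n_nodes_0, int n_nodes_1, int n_nodes_per_cb,
                               int * node_start, int * node_end) {
    *node_start = 0;
    *node_end   = n_nodes_0;

    if (cb_idx < n_cb) {
        *node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
        *node_end   = n_nodes_0 + (MIN((cb_idx == n_cb - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
    }
}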
|
|
|
|
|
2024-01-12 20:07:38 +01:00
|
|
|
static struct ggml_backend_i ggml_backend_metal_i = {
|
2023-12-07 21:26:54 +01:00
|
|
|
/* .get_name = */ ggml_backend_metal_name,
|
|
|
|
/* .free = */ ggml_backend_metal_free,
|
|
|
|
/* .set_tensor_async = */ NULL,
|
|
|
|
/* .get_tensor_async = */ NULL,
|
2024-01-12 20:07:38 +01:00
|
|
|
/* .cpy_tensor_async = */ NULL,
|
2023-12-21 21:07:46 +01:00
|
|
|
/* .synchronize = */ NULL,
|
|
|
|
/* .graph_plan_create = */ NULL,
|
2023-12-07 21:26:54 +01:00
|
|
|
/* .graph_plan_free = */ NULL,
|
2024-06-13 03:11:35 +02:00
|
|
|
/* .graph_plan_update = */ NULL,
|
2023-12-07 21:26:54 +01:00
|
|
|
/* .graph_plan_compute = */ NULL,
|
|
|
|
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
2024-03-13 18:54:21 +01:00
|
|
|
/* .event_record = */ NULL,
|
|
|
|
/* .event_wait = */ NULL,
|
2023-10-08 19:19:14 +02:00
|
|
|
};
|
|
|
|
|
2024-02-24 17:27:36 +01:00
|
|
|
static ggml_guid_t ggml_backend_metal_guid(void) {
|
|
|
|
static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
|
|
|
|
return &guid;
|
|
|
|
}
|
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
// TODO: remove in the future
|
2023-10-08 19:19:14 +02:00
|
|
|
ggml_backend_t ggml_backend_metal_init(void) {
|
2024-10-07 17:27:51 +02:00
|
|
|
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
|
|
|
|
|
|
|
|
struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
|
2023-12-07 21:26:54 +01:00
|
|
|
if (ctx == NULL) {
|
2024-10-03 17:39:03 +02:00
|
|
|
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
2023-12-07 21:26:54 +01:00
|
|
|
return NULL;
|
|
|
|
}
|
2023-10-08 19:19:14 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
|
2023-10-08 19:19:14 +02:00
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
*backend = (struct ggml_backend) {
|
2024-02-24 17:27:36 +01:00
|
|
|
/* .guid = */ ggml_backend_metal_guid(),
|
2024-01-12 20:07:38 +01:00
|
|
|
/* .interface = */ ggml_backend_metal_i,
|
2024-10-07 17:27:51 +02:00
|
|
|
/* .device = */ dev,
|
2023-10-08 19:19:14 +02:00
|
|
|
/* .context = */ ctx,
|
|
|
|
};
|
|
|
|
|
2024-10-01 15:00:25 +02:00
|
|
|
ggml_backend_metal_set_n_cb(backend, 1);
|
|
|
|
|
|
|
|
return backend;
|
2023-10-08 19:19:14 +02:00
|
|
|
}
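// Minimal usage sketch: create the single Metal backend, run work, then release it.
// ggml_backend_free() is assumed to be the generic destructor from ggml-backend.h and ends
// up calling ggml_backend_metal_free() above.
static void example_init_metal(void) {
    ggml_backend_t backend = ggml_backend_metal_init();
    if (backend == NULL) {
        return;
    }

    // ... build a graph and run it with ggml_backend_graph_compute(backend, gf) ...

    ggml_backend_free(backend);
}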
|
|
|
|
|
|
|
|
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
2024-02-24 17:27:36 +01:00
|
|
|
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
|
2023-10-08 19:19:14 +02:00
|
|
|
}
|
|
|
|
|
2024-08-07 08:55:49 +02:00
|
|
|
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
|
|
|
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
|
|
|
|
2024-08-07 08:57:00 +02:00
|
|
|
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
2024-08-07 08:55:49 +02:00
|
|
|
|
|
|
|
ctx->abort_callback = abort_callback;
|
|
|
|
ctx->abort_callback_data = user_data;
|
|
|
|
}
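// Usage sketch (hypothetical flag): stop an in-flight graph once an external condition
// becomes true. The callback is assumed to follow the ggml_abort_callback convention from
// ggml.h, i.e. returning true requests the computation to stop.
static volatile bool g_example_stop = false;

static bool example_should_abort(void * user_data) {
    (void) user_data;
    return g_example_stop;
}

// registered once after creating the backend:
//   ggml_backend_metal_set_abort_callback(backend, example_should_abort, NULL);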
|
|
|
|
|
2023-12-07 21:26:54 +01:00
|
|
|
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
|
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
2023-12-07 21:26:54 +01:00
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
2023-12-07 21:26:54 +01:00
|
|
|
}
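// Usage sketch: the family argument is 1-based (MTLGPUFamilyApple1 + family - 1), so a check
// for MTLGPUFamilyApple7 - the family the simdgroup matrix-multiplication kernels rely on -
// looks like this:
static bool example_has_simdgroup_mm(ggml_backend_t backend) {
    return ggml_backend_metal_supports_family(backend, 7);
}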
|
|
|
|
|
2024-01-29 10:22:23 +01:00
|
|
|
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
|
|
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
|
|
|
|
2024-08-07 08:57:00 +02:00
|
|
|
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
2024-10-01 15:00:25 +02:00
|
|
|
ctx->capture_next_compute = true;
|
2024-01-29 10:22:23 +01:00
|
|
|
}
|
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
// backend device
|
|
|
|
|
|
|
|
static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
|
|
|
|
return "Metal";
|
2023-12-07 21:26:54 +01:00
|
|
|
|
2024-10-07 17:27:51 +02:00
|
|
|
GGML_UNUSED(dev);
|
|
|
|
}
|
|
|
|
|
|
|
|
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
|
|
|
|
// acq/rel just to populate ctx->name in case it hasn't been done yet
|
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
|
|
ggml_backend_metal_device_acq(ctx_dev);
|
|
|
|
ggml_backend_metal_device_rel(ctx_dev);
|
|
|
|
|
|
|
|
return ctx_dev->name;
|
|
|
|
}
|
|
|
|
|
|
|
|
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
|
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
|
|
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
|
|
|
|
|
|
|
*total = device.recommendedMaxWorkingSetSize;
|
|
|
|
*free = *total - device.currentAllocatedSize;
|
|
|
|
|
|
|
|
ggml_backend_metal_device_rel(ctx_dev);
|
|
|
|
} else {
|
|
|
|
*free = 1;
|
|
|
|
*total = 1;
|
|
|
|
}
|
|
|
|
}
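// Usage sketch, assuming ggml_backend_dev_memory() from ggml-backend.h dispatches to the
// function above: query how much of the recommended working set is still available before
// committing to a large allocation.
static size_t example_free_device_memory(ggml_backend_dev_t dev) {
    size_t mem_free  = 0;
    size_t mem_total = 0;
    ggml_backend_dev_memory(dev, &mem_free, &mem_total);
    return mem_free;
}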
|
|
|
|
|
|
|
|
static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
|
2024-10-30 02:01:23 +01:00
|
|
|
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
2024-10-07 17:27:51 +02:00
|
|
|
|
|
|
|
GGML_UNUSED(dev);
|
|
|
|
}
|
|
|
|
|
|
|
|
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
|
|
|
props->name = ggml_backend_metal_device_get_name(dev);
|
|
|
|
props->description = ggml_backend_metal_device_get_description(dev);
|
|
|
|
props->type = ggml_backend_metal_device_get_type(dev);
|
|
|
|
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
|
|
|
props->caps = (struct ggml_backend_dev_caps) {
|
|
|
|
/* .async = */ false,
|
|
|
|
/* .host_buffer = */ false,
|
|
|
|
/* .buffer_from_host_ptr = */ true,
|
|
|
|
/* .events = */ false,
|
|
|
|
};
|
|
|
|
}
|
|
|
|
|
|
|
|
static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
|
|
|
|
struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
|
|
|
|
if (ctx == NULL) {
|
|
|
|
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
|
|
|
return NULL;
|
|
|
|
}
|
|
|
|
|
|
|
|
ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
|
|
|
|
|
|
|
|
*backend = (struct ggml_backend) {
|
|
|
|
/* .guid = */ ggml_backend_metal_guid(),
|
|
|
|
/* .interface = */ ggml_backend_metal_i,
|
|
|
|
/* .device = */ dev,
|
|
|
|
/* .context = */ ctx,
|
|
|
|
};
|
|
|
|
|
|
|
|
ggml_backend_metal_set_n_cb(backend, 1);
|
|
|
|
|
|
|
|
return backend;
|
2023-12-07 21:26:54 +01:00
|
|
|
|
|
|
|
GGML_UNUSED(params);
|
2024-10-07 17:27:51 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
|
|
|
|
return ggml_backend_metal_buffer_type();
|
|
|
|
|
|
|
|
GGML_UNUSED(dev);
|
|
|
|
}
|
|
|
|
|
|
|
|
static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
|
|
|
struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
|
|
|
|
|
|
|
|
ctx->all_data = ptr;
|
|
|
|
ctx->all_size = size;
|
|
|
|
ctx->owned = false;
|
|
|
|
ctx->n_buffers = 0;
|
|
|
|
|
|
|
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
|
|
|
|
|
|
|
// page-align the data ptr
|
|
|
|
{
|
|
|
|
const uintptr_t offs = (uintptr_t) ptr % size_page;
|
|
|
|
ptr = (void *) ((char *) ptr - offs);
|
|
|
|
size += offs;
|
|
|
|
}
|
|
|
|
|
|
|
|
size_t size_aligned = size;
|
|
|
|
if ((size_aligned % size_page) != 0) {
|
|
|
|
size_aligned += (size_page - (size_aligned % size_page));
|
|
|
|
}
|
|
|
|
|
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
|
|
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
|
|
|
|
|
|
|
// the buffer fits into the max buffer size allowed by the device
|
|
|
|
if (size_aligned <= device.maxBufferLength) {
|
|
|
|
ctx->buffers[ctx->n_buffers].data = ptr;
|
|
|
|
ctx->buffers[ctx->n_buffers].size = size;
|
|
|
|
ctx->buffers[ctx->n_buffers].metal = nil;
|
|
|
|
|
|
|
|
if (size_aligned > 0) {
|
|
|
|
ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
|
|
|
|
|
|
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
|
|
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
|
|
|
return NULL;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
ggml_backend_metal_log_allocated_size(device, size_aligned);
|
|
|
|
|
|
|
|
++ctx->n_buffers;
|
|
|
|
} else {
|
|
|
|
// this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
|
|
|
|
// one of the views
|
|
|
|
const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
|
|
|
|
const size_t size_step = device.maxBufferLength - size_ovlp;
|
|
|
|
const size_t size_view = device.maxBufferLength;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < size; i += size_step) {
|
|
|
|
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
|
|
|
|
|
|
|
|
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i);
|
|
|
|
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
|
|
|
|
ctx->buffers[ctx->n_buffers].metal = nil;
|
|
|
|
|
|
|
|
if (size_step_aligned > 0) {
|
|
|
|
ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
|
|
|
|
|
|
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
|
|
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
|
|
|
|
return NULL;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
ggml_backend_metal_log_allocated_size(device, size_step_aligned);
|
|
|
|
|
|
|
|
if (i + size_step < size) {
|
|
|
|
GGML_LOG_INFO("\n");
|
|
|
|
}
|
|
|
|
|
|
|
|
++ctx->n_buffers;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-11-06 12:10:07 +01:00
|
|
|
return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
|
2024-10-07 17:27:51 +02:00
|
|
|
}
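// Usage sketch, assuming the generic wrapper ggml_backend_dev_buffer_from_host_ptr() from
// ggml-backend.h routes to the function above through the device interface below: expose an
// existing host allocation (e.g. an mmap-ed model file) as a zero-copy Metal buffer.
static ggml_backend_buffer_t example_wrap_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
    return ggml_backend_dev_buffer_from_host_ptr(dev, ptr, size, max_tensor_size);
}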
|
|
|
|
|
|
|
|
static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
|
|
|
|
|
|
|
return ggml_metal_supports_op(ctx_dev, op);
|
|
|
|
}
|
|
|
|
|
|
|
|
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
2024-11-06 12:10:07 +01:00
|
|
|
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
|
|
|
|
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
|
2024-10-07 17:27:51 +02:00
|
|
|
|
|
|
|
UNUSED(dev);
|
|
|
|
}
|
|
|
|
|
|
|
|
static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
|
|
|
return false;
|
|
|
|
|
|
|
|
GGML_UNUSED(dev);
|
|
|
|
GGML_UNUSED(op);
|
|
|
|
}
|
|
|
|
|
|
|
|
static struct ggml_backend_device_i ggml_backend_metal_device_i = {
|
|
|
|
/* .get_name = */ ggml_backend_metal_device_get_name,
|
|
|
|
/* .get_description = */ ggml_backend_metal_device_get_description,
|
|
|
|
/* .get_memory = */ ggml_backend_metal_device_get_memory,
|
|
|
|
/* .get_type = */ ggml_backend_metal_device_get_type,
|
|
|
|
/* .get_props = */ ggml_backend_metal_device_get_props,
|
|
|
|
/* .init_backend = */ ggml_backend_metal_device_init,
|
|
|
|
/* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type,
|
|
|
|
/* .get_host_buffer_type = */ NULL,
|
|
|
|
/* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
|
|
|
|
/* .supports_op = */ ggml_backend_metal_device_supports_op,
|
|
|
|
/* .supports_buft = */ ggml_backend_metal_device_supports_buft,
|
|
|
|
/* .offload_op = */ ggml_backend_metal_device_offload_op,
|
|
|
|
/* .event_new = */ NULL,
|
|
|
|
/* .event_free = */ NULL,
|
|
|
|
/* .event_synchronize = */ NULL,
|
|
|
|
};
|
|
|
|
|
|
|
|
// backend registry
|
|
|
|
|
|
|
|
static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
|
|
|
|
return "Metal";
|
|
|
|
|
|
|
|
GGML_UNUSED(reg);
|
|
|
|
}
|
|
|
|
|
|
|
|
static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
|
|
|
|
return 1;
|
|
|
|
|
|
|
|
GGML_UNUSED(reg);
|
|
|
|
}
|
|
|
|
|
|
|
|
static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
|
|
|
|
GGML_ASSERT(index == 0);
|
|
|
|
|
|
|
|
return &g_ggml_backend_metal_device;
|
|
|
|
|
|
|
|
GGML_UNUSED(reg);
|
|
|
|
GGML_UNUSED(index);
|
|
|
|
}
|
|
|
|
|
|
|
|
static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
|
|
|
|
/* .get_name = */ ggml_backend_metal_reg_get_name,
|
|
|
|
/* .device_count = */ ggml_backend_metal_reg_device_count,
|
|
|
|
/* .device_get = */ ggml_backend_metal_reg_device_get,
|
|
|
|
/* .get_proc_address = */ NULL,
|
|
|
|
};
|
|
|
|
|
|
|
|
ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
|
|
|
// TODO: make this thread-safe somehow?
|
|
|
|
{
|
|
|
|
g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
|
|
|
|
/* .iface = */ ggml_backend_metal_reg_i,
|
|
|
|
/* .context = */ NULL,
|
|
|
|
};
|
|
|
|
|
|
|
|
g_ggml_backend_metal_device = (struct ggml_backend_device) {
|
|
|
|
/* .iface = */ ggml_backend_metal_device_i,
|
|
|
|
/* .reg = */ &g_ggml_backend_metal_reg,
|
|
|
|
/* .context = */ &g_ggml_ctx_dev_main,
|
|
|
|
};
|
|
|
|
}
|
|
|
|
|
|
|
|
return &g_ggml_backend_metal_reg;
|
2023-12-07 21:26:54 +01:00
|
|
|
}
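// Usage sketch: initialize the backend through the registry instead of the legacy
// ggml_backend_metal_init(). ggml_backend_reg_dev_get() is already used above;
// ggml_backend_dev_init() is assumed to be the generic device initializer from
// ggml-backend.h that forwards to ggml_backend_metal_device_init().
static ggml_backend_t example_init_via_registry(void) {
    ggml_backend_reg_t reg = ggml_backend_metal_reg();
    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);

    return ggml_backend_dev_init(dev, /*params =*/ NULL);
}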
|