mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 06:19:02 +01:00
metal : fix build and some more comments (#10229)
This commit is contained in:
parent
bb38cdd8ba
commit
39a334a9aa
@ -3041,6 +3041,8 @@ static void ggml_metal_encode_node(
|
||||
|
||||
bool use_vec_kernel = false;
|
||||
|
||||
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
||||
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
||||
if (ne01 >= 4 || (ne00%128 != 0)) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
|
@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const short D4 = D/4;
|
||||
const short D16 = D/16;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short NL = NW/4;
|
||||
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
|
||||
const short SH = 2*C; // shared memory per simdgroup
|
||||
|
||||
const short T = D + nsg*SH; // shared memory size per query in (half)
|
||||
@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
// each simdgroup processes 1 query and 4 keys
|
||||
// each simdgroup processes 1 query and 4 (NW/NL) keys
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
|
||||
|
||||
@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
half, half4, half4x4, \
|
||||
half4x4
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
|
Loading…
Reference in New Issue
Block a user