Vulkan Intel Fixes, Optimizations and Debugging Flags (#5301)

* Fix Vulkan on Intel ARC

Optimize matmul for Intel ARC

Add Vulkan dequant test

* Add Vulkan debug and validate flags to Make and CMakeLists.txt

* Enable asynchronous transfers in Vulkan backend

* Fix flake8

* Disable Vulkan async backend functions for now

* Also add Vulkan run tests command to Makefile and CMakeLists.txt
0cc4m authored on 2024-02-03 18:15:00 +01:00 (committed by GitHub)
commit e920ed393d, parent 52bb63c708
5 changed files with 1257 additions and 10330 deletions

CMakeLists.txt

@@ -100,6 +100,10 @@ option(LLAMA_HIPBLAS "llama: use hipBLAS"
 option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 option(LLAMA_VULKAN "llama: use Vulkan" OFF)
+option(LLAMA_VULKAN_CHECK_RESULTS "llama: run Vulkan op checks" OFF)
+option(LLAMA_VULKAN_DEBUG "llama: enable Vulkan debug output" OFF)
+option(LLAMA_VULKAN_VALIDATE "llama: enable Vulkan validation" OFF)
+option(LLAMA_VULKAN_RUN_TESTS "llama: run Vulkan tests" OFF)
 option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
 option(LLAMA_METAL_SHADER_DEBUG "llama: compile Metal with -fno-fast-math" OFF)
@@ -431,6 +435,22 @@ if (LLAMA_VULKAN)
         add_compile_definitions(GGML_USE_VULKAN)
 
+        if (LLAMA_VULKAN_CHECK_RESULTS)
+            target_compile_definitions(ggml-vulkan PRIVATE GGML_VULKAN_CHECK_RESULTS)
+        endif()
+
+        if (LLAMA_VULKAN_DEBUG)
+            target_compile_definitions(ggml-vulkan PRIVATE GGML_VULKAN_DEBUG)
+        endif()
+
+        if (LLAMA_VULKAN_VALIDATE)
+            target_compile_definitions(ggml-vulkan PRIVATE GGML_VULKAN_VALIDATE)
+        endif()
+
+        if (LLAMA_VULKAN_RUN_TESTS)
+            target_compile_definitions(ggml-vulkan PRIVATE GGML_VULKAN_RUN_TESTS)
+        endif()
+
         set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-vulkan)
     else()
         message(WARNING "Vulkan not found")

Makefile

@@ -457,6 +457,18 @@ ifdef LLAMA_VULKAN_CHECK_RESULTS
 	MK_CPPFLAGS += -DGGML_VULKAN_CHECK_RESULTS
 endif
 
+ifdef LLAMA_VULKAN_DEBUG
+	MK_CPPFLAGS += -DGGML_VULKAN_DEBUG
+endif
+
+ifdef LLAMA_VULKAN_VALIDATE
+	MK_CPPFLAGS += -DGGML_VULKAN_VALIDATE
+endif
+
+ifdef LLAMA_VULKAN_RUN_TESTS
+	MK_CPPFLAGS += -DGGML_VULKAN_RUN_TESTS
+endif
+
 ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 endif # LLAMA_VULKAN
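For reference, each of the new switches only adds the matching GGML_VULKAN_* preprocessor define, so it is toggled at configure/build time. The snippet below is a minimal sketch of one way to pass the flags to either build system; it is not part of this commit and assumes cmake and GNU make are on PATH and that the commands run from the repository root.

    # Illustrative only: drive the two build systems with the new Vulkan flags.
    import subprocess

    # CMake: every option() added above becomes a -D<NAME>=ON switch.
    subprocess.run([
        "cmake", "-B", "build",
        "-DLLAMA_VULKAN=ON",
        "-DLLAMA_VULKAN_DEBUG=ON",          # defines GGML_VULKAN_DEBUG
        "-DLLAMA_VULKAN_VALIDATE=ON",       # defines GGML_VULKAN_VALIDATE
        "-DLLAMA_VULKAN_CHECK_RESULTS=ON",  # defines GGML_VULKAN_CHECK_RESULTS
        "-DLLAMA_VULKAN_RUN_TESTS=ON",      # defines GGML_VULKAN_RUN_TESTS
    ], check=True)
    subprocess.run(["cmake", "--build", "build"], check=True)

    # Makefile: the ifdef blocks only test whether the variable is defined,
    # so any value (conventionally 1) enables the corresponding define.
    subprocess.run(["make", "LLAMA_VULKAN=1", "LLAMA_VULKAN_DEBUG=1", "LLAMA_VULKAN_VALIDATE=1"], check=True)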

File diff suppressed because it is too large

File diff suppressed because it is too large

ggml_vk_generate_shaders.py

@@ -157,19 +157,10 @@ struct block_q6_K
 
 # Dequant functions
 shader_f16_dequant_func = """
-#define DEQUANT_FUNC f16vec2 v = f16vec2(data_a[ib + 0], data_a[ib + 1]);
-"""
-shader_f16_dequant_func_compat = """
 #define DEQUANT_FUNC vec2 v = vec2(data_a[ib + 0], data_a[ib + 1]);
 """
 
 shader_q4_0_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-const uint8_t vui = data_a[ib].qs[iqs]; \
-f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \
-v = (v - 8.0hf)*d;
-"""
-shader_q4_0_dequant_func_compat = """
 #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
 const uint vui = uint(data_a[ib].qs[iqs]); \
 vec2 v = vec2(vui & 0xF, vui >> 4); \
@@ -177,13 +168,6 @@ v = (v - 8.0f)*d;
 """
 
 shader_q4_1_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-const float16_t m = data_a[ib].m; \
-const uint8_t vui = data_a[ib].qs[iqs]; \
-f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \
-v = v*d + m;
-"""
-shader_q4_1_dequant_func_compat = """
 #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
 const float m = float(data_a[ib].m); \
 const uint vui = uint(data_a[ib].qs[iqs]); \
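As a plain-Python cross-check of what the Q4_0 and Q4_1 DEQUANT_FUNC macros above compute (illustrative only, not part of the commit; d, m and qs stand for the block's scale, minimum, and packed 4-bit quants):

    # Reference dequantization matching the Q4_0 / Q4_1 macros above.
    def dequant_q4_0(d: float, qs: bytes, iqs: int) -> tuple[float, float]:
        # One byte packs two 4-bit quants; Q4_0 subtracts a fixed offset of 8.
        vui = qs[iqs]
        return ((vui & 0xF) - 8.0) * d, ((vui >> 4) - 8.0) * d

    def dequant_q4_1(d: float, m: float, qs: bytes, iqs: int) -> tuple[float, float]:
        # Q4_1 stores an explicit minimum m instead of the fixed offset.
        vui = qs[iqs]
        return (vui & 0xF) * d + m, (vui >> 4) * d + m

    # Byte 0xA3 packs the quants 3 (low nibble) and 10 (high nibble).
    print(dequant_q4_0(0.5, bytes([0xA3]), 0))       # (-2.5, 1.0)
    print(dequant_q4_1(0.5, 1.0, bytes([0xA3]), 0))  # (2.5, 6.0)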
@@ -192,14 +176,6 @@ v = v*d + m;
 """
 
 shader_q5_0_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \
-const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
-const uint8_t vui = data_a[ib].qs[iqs]; \
-f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
-v = (v - 16.0hf) * d;
-"""
-shader_q5_0_dequant_func_compat = """
 #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
 const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \
 const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
@@ -209,14 +185,6 @@ v = (v - 16.0f) * d;
 """
 
 shader_q5_1_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-const float16_t m = data_a[ib].m; \
-const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \
-const uint8_t vui = data_a[ib].qs[iqs]; \
-f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
-v = v*d + m;
-"""
-shader_q5_1_dequant_func_compat = """
 #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
 const float m = float(data_a[ib].m); \
 const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \
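The 5-bit formats add one twist: the fifth bit of each quant lives in a separate qh bitfield (reassembled from two 16-bit halves for Q5_0) and is OR-ed onto the nibble before scaling. A plain-Python sketch of the same arithmetic, illustrative only and not part of the commit:

    # Reference dequantization matching the Q5_0 / Q5_1 macros above.
    def high_bits(qh: int, iqs: int) -> tuple[int, int]:
        # Bit iqs is the 5th bit of the low-nibble quant, bit iqs+16 that of the high-nibble quant.
        return ((qh >> iqs) << 4) & 0x10, (qh >> (iqs + 12)) & 0x10

    def dequant_q5_0(d: float, qh: int, qs: bytes, iqs: int) -> tuple[float, float]:
        hx, hy = high_bits(qh, iqs)
        vui = qs[iqs]
        return (((vui & 0xF) | hx) - 16.0) * d, (((vui >> 4) | hy) - 16.0) * d

    def dequant_q5_1(d: float, m: float, qh: int, qs: bytes, iqs: int) -> tuple[float, float]:
        hx, hy = high_bits(qh, iqs)
        vui = qs[iqs]
        return ((vui & 0xF) | hx) * d + m, ((vui >> 4) | hy) * d + m

    # qh = 0x1 sets only the 5th bit of the low-nibble quant: 3 becomes 19, 10 stays 10.
    print(dequant_q5_0(0.5, 0x1, bytes([0xA3]), 0))       # (1.5, -3.0)
    print(dequant_q5_1(0.5, 1.0, 0x1, bytes([0xA3]), 0))  # (10.5, 6.0)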
@@ -226,11 +194,6 @@ v = v*d + m;
 """
 
 shader_q8_0_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-f16vec2 v = f16vec2(data_a[ib].qs[iqs], data_a[ib].qs[iqs + 1]); \
-v = v * d;
-"""
-shader_q8_0_dequant_func_compat = """
 #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
 vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])); \
 v = v * d;
@@ -2110,7 +2073,7 @@ lock = asyncio.Lock()
 shader_fnames = []
 
-async def string_to_spv(name, code, defines, fp16):
+async def string_to_spv(name, code, defines, fp16=True):
     f = NamedTemporaryFile(mode="w", delete=False)
     f.write(code)
     f.flush()
@@ -2200,64 +2163,6 @@ async def main():
         tasks.append(string_to_spv("matmul_f16_f32_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
         tasks.append(string_to_spv("matmul_f16_f32_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
 
-        # Build dequant shaders
-        tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {}, fp16))
-
-        for i in range(0, VK_NUM_TYPES):
-            stream.clear()
-            stream.extend((dequant_head, shader_int8_ext, shader_float_type))
-
-            if i == GGML_TYPE_F16:
-                stream.extend((shader_f16_defines, shader_f16_dequant_func_compat if not fp16 else shader_f16_dequant_func, dequant_body))
-            elif i == GGML_TYPE_Q4_0:
-                stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func_compat if not fp16 else shader_q4_0_dequant_func, dequant_body))
-            elif i == GGML_TYPE_Q4_1:
-                stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func_compat if not fp16 else shader_q4_1_dequant_func, dequant_body))
-            elif i == GGML_TYPE_Q5_0:
-                stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func_compat if not fp16 else shader_q5_0_dequant_func, dequant_body))
-            elif i == GGML_TYPE_Q5_1:
-                stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func_compat if not fp16 else shader_q5_1_dequant_func, dequant_body))
-            elif i == GGML_TYPE_Q8_0:
-                stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func_compat if not fp16 else shader_q8_0_dequant_func, dequant_body))
-            elif i == GGML_TYPE_Q2_K:
-                stream.extend((shader_q2_K_defines, dequant_q2_K_body))
-            elif i == GGML_TYPE_Q3_K:
-                stream.extend((shader_q3_K_defines, dequant_q3_K_body))
-            elif i == GGML_TYPE_Q4_K:
-                stream.extend((shader_q4_K_defines, dequant_q4_K_body))
-            elif i == GGML_TYPE_Q5_K:
-                stream.extend((shader_q5_K_defines, dequant_q5_K_body))
-            elif i == GGML_TYPE_Q6_K:
-                stream.extend((shader_q6_K_defines, dequant_q6_K_body))
-            else:
-                continue
-
-            tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}, fp16))
-
-        # get_rows
-        for i in range(0, VK_NUM_TYPES):
-            stream.clear()
-            stream.extend((generic_head, shader_int8_ext, shader_float_type))
-
-            if i == GGML_TYPE_F16:
-                stream.extend((shader_f16_defines, shader_f16_dequant_func_compat if not fp16 else shader_f16_dequant_func, get_rows_body))
-            elif i == GGML_TYPE_Q4_0:
-                stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func_compat if not fp16 else shader_q4_0_dequant_func, get_rows_body))
-            elif i == GGML_TYPE_Q4_1:
-                stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func_compat if not fp16 else shader_q4_1_dequant_func, get_rows_body))
-            elif i == GGML_TYPE_Q5_0:
-                stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func_compat if not fp16 else shader_q5_0_dequant_func, get_rows_body))
-            elif i == GGML_TYPE_Q5_1:
-                stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func_compat if not fp16 else shader_q5_1_dequant_func, get_rows_body))
-            elif i == GGML_TYPE_Q8_0:
-                stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func_compat if not fp16 else shader_q8_0_dequant_func, get_rows_body))
-            else:
-                continue
-
-            tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float16_t"}, fp16))
-            tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float"}, fp16))
-
     # Shaders where precision is needed, so no fp16 version
 
     # mul mat vec
@@ -2266,17 +2171,17 @@ async def main():
         stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32))
 
         if i == GGML_TYPE_F16:
-            stream.extend((shader_f16_defines, shader_f16_dequant_func_compat, mul_mat_vec_body))
+            stream.extend((shader_f16_defines, shader_f16_dequant_func, mul_mat_vec_body))
         elif i == GGML_TYPE_Q4_0:
-            stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func_compat, mul_mat_vec_body))
+            stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, mul_mat_vec_body))
         elif i == GGML_TYPE_Q4_1:
-            stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func_compat, mul_mat_vec_body))
+            stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, mul_mat_vec_body))
         elif i == GGML_TYPE_Q5_0:
-            stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func_compat, mul_mat_vec_body))
+            stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, mul_mat_vec_body))
         elif i == GGML_TYPE_Q5_1:
-            stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func_compat, mul_mat_vec_body))
+            stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, mul_mat_vec_body))
         elif i == GGML_TYPE_Q8_0:
-            stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func_compat, mul_mat_vec_body))
+            stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, mul_mat_vec_body))
         elif i == GGML_TYPE_Q2_K:
            stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body))
         elif i == GGML_TYPE_Q3_K:
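These mul_mat_vec shaders now reference the (float-only) dequant functions directly; as the comment above notes, precision matters here, since half-precision arithmetic rounds each scaled value and the error accumulates over a long dot product. A small NumPy illustration of that rounding, not from the commit and assuming numpy is installed:

    # Illustrative only: fp16 vs fp32 rounding of a single dequantized product.
    import numpy as np

    d = np.float32(0.0123)   # block scale
    q = 13                   # a 4-bit quant (offset 8 applies)

    exact  = (q - 8) * float(d)                                # float64 reference
    as_f32 = (np.float32(q) - np.float32(8)) * d               # float dequant path
    as_f16 = (np.float16(q) - np.float16(8)) * np.float16(d)   # half-precision path

    print(exact, as_f32, np.float32(as_f16))
    # The fp16 product (~0.061493) already sits visibly below the fp32 one (~0.0615),
    # before any accumulation across a row.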
@@ -2290,43 +2195,101 @@ async def main():
         else:
             continue
 
-        tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}, fp16))
+        tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
 
-    tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", mul_mat_p021_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, True))
-    tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", mul_mat_nc_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, True))
+    # Dequant shaders
+    for i in range(0, VK_NUM_TYPES):
+        stream.clear()
+        stream.extend((dequant_head, shader_int8_ext, shader_f32))
+
+        if i == GGML_TYPE_F16:
+            stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body))
+        elif i == GGML_TYPE_Q4_0:
+            stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, dequant_body))
+        elif i == GGML_TYPE_Q4_1:
+            stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body))
+        elif i == GGML_TYPE_Q5_0:
+            stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, dequant_body))
+        elif i == GGML_TYPE_Q5_1:
+            stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, dequant_body))
+        elif i == GGML_TYPE_Q8_0:
+            stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, dequant_body))
+        elif i == GGML_TYPE_Q2_K:
+            stream.extend((shader_q2_K_defines, dequant_q2_K_body))
+        elif i == GGML_TYPE_Q3_K:
+            stream.extend((shader_q3_K_defines, dequant_q3_K_body))
+        elif i == GGML_TYPE_Q4_K:
+            stream.extend((shader_q4_K_defines, dequant_q4_K_body))
+        elif i == GGML_TYPE_Q5_K:
+            stream.extend((shader_q5_K_defines, dequant_q5_K_body))
+        elif i == GGML_TYPE_Q6_K:
+            stream.extend((shader_q6_K_defines, dequant_q6_K_body))
+        else:
+            continue
+
+        tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}))
+
+    tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {}))
+
+    # get_rows
+    for i in range(0, VK_NUM_TYPES):
+        stream.clear()
+        stream.extend((generic_head, shader_int8_ext, shader_f32))
+
+        if i == GGML_TYPE_F16:
+            stream.extend((shader_f16_defines, shader_f16_dequant_func, get_rows_body))
+        elif i == GGML_TYPE_Q4_0:
+            stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, get_rows_body))
+        elif i == GGML_TYPE_Q4_1:
+            stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, get_rows_body))
+        elif i == GGML_TYPE_Q5_0:
+            stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, get_rows_body))
+        elif i == GGML_TYPE_Q5_1:
+            stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, get_rows_body))
+        elif i == GGML_TYPE_Q8_0:
+            stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, get_rows_body))
+        else:
+            continue
+
+        tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float16_t"}))
+        tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float"}))
+
+    tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", mul_mat_p021_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
+    tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", mul_mat_nc_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
 
     # Norms
-    tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("cpy_f32_f32", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("cpy_f32_f32", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("cpy_f32_f16", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}, True))
+    tasks.append(string_to_spv("cpy_f32_f16", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
-    tasks.append(string_to_spv("cpy_f16_f16", f"{cpy_src}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}, True))
+    tasks.append(string_to_spv("cpy_f16_f16", f"{cpy_src}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
 
-    tasks.append(string_to_spv("add_f32", f"{generic_head}\n{shader_f32}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("add_f32", f"{generic_head}\n{shader_f32}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}, True))
+    tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}))
 
-    tasks.append(string_to_spv("mul_f32", f"{generic_head}\n{shader_f32}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("mul_f32", f"{generic_head}\n{shader_f32}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("scale_f32", f"{generic_head}\n{shader_f32}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("scale_f32", f"{generic_head}\n{shader_f32}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("sqr_f32", f"{generic_head}\n{shader_f32}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("sqr_f32", f"{generic_head}\n{shader_f32}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("clamp_f32", f"{generic_head}\n{shader_f32}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("clamp_f32", f"{generic_head}\n{shader_f32}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("relu_f32", f"{generic_head}\n{shader_f32}\n{relu_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("relu_f32", f"{generic_head}\n{shader_f32}\n{relu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("soft_max_f32", f"{generic_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("soft_max_f32", f"{generic_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}, True))
+    tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
 
-    tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}, True))
+    tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}, True))
+    tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
 
     await asyncio.gather(*tasks)