Vulkan: Fix float16 use on devices without float16 support + fix subgroup_size_control validation error (#11161)

* Vulkan: Remove float16 use in shaders

* Fix validation error about subgroup_size_control extension
Commit: c3f9d25706 (parent: ee7136c6d1)
Author: 0cc4m
Date:   2025-01-10 06:39:33 +01:00 (committed by GitHub)
9 changed files with 50 additions and 51 deletions

File 1 of 9

@@ -2277,6 +2277,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     if (device->subgroup_size_control) {
         device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
         device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
+        device_extensions.push_back("VK_EXT_subgroup_size_control");
     }

     device->subgroup_size_control = device->subgroup_size_control &&
@@ -2285,7 +2286,6 @@ static vk_device ggml_vk_get_device(size_t idx) {

     if (device->subgroup_size_control) {
         device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
-        device_extensions.push_back("VK_EXT_subgroup_size_control");
    }

 #if defined(VK_KHR_cooperative_matrix)
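The host-side fix is purely an ordering change: VK_EXT_subgroup_size_control is now requested as soon as the device reports support for it, rather than only inside the later block that also checks computeFullSubgroups. Otherwise the subgroup-size limits could be consumed with the extension never enabled, which is the kind of mismatch the validation layers flag. Below is a minimal standalone sketch of the corrected pattern; the function name and surrounding structure are hypothetical simplifications, not the exact ggml-vulkan.cpp logic:

    #include <vector>
    #include <vulkan/vulkan.h>

    // Sketch: enable VK_EXT_subgroup_size_control whenever the device offers it,
    // *before* its min/max subgroup-size limits are used anywhere else.
    static void pick_subgroup_size_control(VkPhysicalDevice phys_dev,
                                           std::vector<const char *> & device_extensions) {
        VkPhysicalDeviceSubgroupSizeControlFeatures sscf = {};
        sscf.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES;
        VkPhysicalDeviceFeatures2 feats2 = {};
        feats2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
        feats2.pNext = &sscf;
        vkGetPhysicalDeviceFeatures2(phys_dev, &feats2);

        if (sscf.subgroupSizeControl) {
            // Request the extension unconditionally on support; creating a pipeline
            // with a required subgroup size is invalid without it being enabled.
            device_extensions.push_back(VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME);
        }

        // computeFullSubgroups only decides whether full subgroups are demanded;
        // it must not gate enabling the extension itself.
        const bool require_full_subgroups = sscf.subgroupSizeControl && sscf.computeFullSubgroups;
        (void) require_full_subgroups;
    }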

File 2 of 9

@@ -1,9 +1,6 @@
 #version 450

-#ifdef FLOAT16
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#endif
-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

 #include "mul_mat_vec_base.comp"
@@ -27,8 +24,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
 #if K_PER_ITER == 8
 #if QUANT_R == 2
-        const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
-        const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
+        const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
+        const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
         const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
         const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
 #else
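Every shader hunk in this commit repeats the same pattern: values that may sit in the buffers as 16-bit floats (the B_TYPE_VEC2/B_TYPE_VEC4 loads, and the f16vec2 scale pair d in the K-quant shaders below) are widened to 32-bit vec2/vec4 at the point of the load, so the arithmetic itself never needs GL_EXT_shader_explicit_arithmetic_types_float16. A minimal self-contained sketch of that load-and-widen pattern, with hypothetical bindings and buffer names:

    #version 450
    // fp16 storage without fp16 arithmetic: GL_EXT_shader_16bit_storage permits
    // declaring f16vec4 members in buffer blocks and converting them to 32-bit
    // types; it does not allow fp16 math, and none is performed here.
    #extension GL_EXT_shader_16bit_storage : require

    layout(local_size_x = 64) in;

    layout(std430, binding = 0) readonly  buffer B { f16vec4 data_b_v4[]; };
    layout(std430, binding = 1) writeonly buffer D { vec4    data_d[];    };

    void main() {
        const uint i = gl_GlobalInvocationID.x;
        // vec4(...) converts to float32 at the load; everything after this point
        // runs on 32-bit floats, so the shaderFloat16 device feature is not needed.
        const vec4 bv = vec4(data_b_v4[i]);
        data_d[i] = bv * bv + vec4(1.0);
    }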

File 3 of 9

@@ -1,5 +1,5 @@
 #version 450

-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

 #include "mul_mat_vec_base.comp"
@@ -40,9 +40,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;

-        f16vec2 d = data_a[ib0 + i].d;
-        const FLOAT_TYPE dall = d.x;
-        const FLOAT_TYPE dmin = d.y;
+        vec2 d = vec2(data_a[ib0 + i].d);
+        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

         uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
         uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
@@ -63,14 +63,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         uvec2 qs16 = uvec2(unpack8(qs16_u16));

         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
-            B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
-            B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
-            B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
-            B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
-            B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
-            B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
-            B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+            vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
+            vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
+            vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
+            vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
+            vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
+            vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
+            vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
+            vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);

             FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
             FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);

File 4 of 9

@@ -1,5 +1,5 @@
 #version 450

-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

 #include "mul_mat_vec_base.comp"
@@ -60,14 +60,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
-            B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
-            B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
-            B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
-            B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
-            B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
-            B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
-            B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+            vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
+            vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
+            vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
+            vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
+            vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
+            vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
+            vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
+            vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);

             FLOAT_TYPE sum = FLOAT_TYPE(0.0);
             [[unroll]] for (int l = 0; l < 2; ++l) {

File 5 of 9

@@ -1,6 +1,6 @@
 #version 450

-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

 #include "mul_mat_vec_base.comp"
@@ -45,7 +45,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;

-        f16vec2 d = data_a[ib0 + i].d;
+        vec2 d = vec2(data_a[ib0 + i].d);
         const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
         const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
@@ -96,10 +96,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         const uint32_t q4_15 = qs64_hi4.w;

         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
-            B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
-            B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
-            B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
+            vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
+            vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
+            vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
+            vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);

             const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
             const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));

File 6 of 9

@@ -1,6 +1,6 @@
 #version 450

-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

 #include "mul_mat_vec_base.comp"
@@ -42,7 +42,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;

-        f16vec2 d = data_a[ib0 + i].d;
+        vec2 d = vec2(data_a[ib0 + i].d);
         const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
         const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
@@ -105,14 +105,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         const uint32_t q4_15 = qs64_80_hi4.w;

         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
-            B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
-            B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
-            B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
-            B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
-            B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
-            B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
-            B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
+            vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
+            vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
+            vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
+            vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
+            vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
+            vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
+            vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
+            vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);

             const FLOAT_TYPE sx =
                 fma(FLOAT_TYPE(by10.x), q4_0,

File 7 of 9

@@ -1,6 +1,6 @@
 #version 450

-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

 #include "mul_mat_vec_base.comp"
@@ -77,10 +77,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         uvec4 q3 = uvec4(unpack8(q3_u32));

         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
-            B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
-            B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
-            B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
+            vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
+            vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
+            vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
+            vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);

             FLOAT_TYPE sum = FLOAT_TYPE(0.0);
             [[unroll]] for (int l = 0; l < 4; ++l) {

File 8 of 9

@@ -1,6 +1,5 @@
 #version 450

-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #extension GL_EXT_control_flow_attributes : enable

 layout (push_constant) uniform parameter

File 9 of 9

@@ -2,7 +2,10 @@
 #if !defined(GGML_TYPES_COMP)
 #define GGML_TYPES_COMP

-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+#extension GL_EXT_shader_16bit_storage : require

 #if defined(DATA_A_F32)
 #define QUANT_K 1
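The types.comp change is the other half of the cleanup: the umbrella GL_EXT_shader_explicit_arithmetic_types is replaced with exactly the sub-extensions these shaders use (int32/int16/int8 plus 16-bit storage). The likely benefit is that under the umbrella an accidental float16_t expression compiles fine and only fails later as a missing-capability validation error on devices without shaderFloat16, whereas with the split extensions it becomes a shader-compile-time error. A compilable sketch of the resulting header discipline, with illustrative values:

    #version 450
    // Require only what the code paths use; the umbrella extension would also
    // legalize float16_t arithmetic, which devices without shaderFloat16 lack.
    #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
    #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
    #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
    #extension GL_EXT_shader_16bit_storage : require

    layout(local_size_x = 1) in;

    void main() {
        // Packed-scale style unpacking, as in the K-quant shaders above:
        const uint32_t packed = 0x01020304u;
        const u8vec4   bytes  = unpack8(packed);      // u8vec4 needs the int8 types
        const uint     s      = uint(bytes.x) + uint(bytes.w);
        // float16_t x = float16_t(s);  // would now fail to compile: not enabled
    }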