diff --git a/ggml-metal.metal b/ggml-metal.metal index 8d333745d..44efbbf17 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -423,8 +423,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre } // putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 //Note: This is a template, but strictly speaking it only applies to // quantizations where the block size is 32. It also does not @@ -435,18 +435,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + device const block_q_type * x = (device const block_q_type *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; // src1 vector cache - float sumf[nr]={0.f}; - const int ix = tiisg/2; - const int il = 8*(tiisg%2); + float yl[16]; // src1 vector cache + float sumf[nr] = {0.f}; + + const int ix = (tiisg/2); + const int il = (tiisg%2)*8; device const float * yb = y + ix * QK4_0 + il; @@ -457,6 +462,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device sumy += yb[i] + yb[i+1]; yl[i+0] = yb[i+ 0]; yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; yl[i+8] = yb[i+16]/16.f; yl[i+9] = yb[i+17]/4096.f; @@ -472,7 +478,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; } } } @@ -490,8 +496,8 @@ kernel void kernel_mul_mv_q4_0_f32( constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } @@ -669,7 +675,7 @@ kernel void kernel_mul_mv_f16_f32_1row( constant int64_t & ne0, constant int64_t & ne1, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y;