diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 1dc117af2..7b48c41bd 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -4307,7 +4307,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8 template static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy, - const sycl::nd_item<3> &item_ct1) { + uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { const block_q4_K * x = (const block_q4_K *) vx; const int i = item_ct1.get_group(2); @@ -4325,13 +4325,17 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri const float dall = dm[0]; const float dmin = dm[1]; + if (tid < 12) + scales_local[tid] = x[i].scales[tid]; + item_ct1.barrier(sycl::access::fence_space::local_space); + const uint8_t * q = x[i].qs + 32*il + n*ir; uint8_t sc, m; - get_scale_min_k4(is + 0, x[i].scales, sc, m); + get_scale_min_k4(is + 0, scales_local, sc, m); const float d1 = dall * sc; const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[i].scales, sc, m); + get_scale_min_k4(is + 1, scales_local, sc, m); const float d2 = dall * sc; const float m2 = dmin * m; for (int l = 0; l < n; ++l) { @@ -9894,12 +9898,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { - dequantize_block_q4_K(vx, y, item_ct1); + dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1); }); + }); } }