mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 20:43:07 +01:00
Store scales in local mem
This commit is contained in:
parent
cb3fb42046
commit
604ef6bf15
@ -4307,7 +4307,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
|
|||||||
|
|
||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
|
||||||
const block_q4_K * x = (const block_q4_K *) vx;
|
const block_q4_K * x = (const block_q4_K *) vx;
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int i = item_ct1.get_group(2);
|
||||||
@ -4325,13 +4325,17 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
|||||||
const float dall = dm[0];
|
const float dall = dm[0];
|
||||||
const float dmin = dm[1];
|
const float dmin = dm[1];
|
||||||
|
|
||||||
|
if (tid < 12)
|
||||||
|
scales_local[tid] = x[i].scales[tid];
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
const uint8_t * q = x[i].qs + 32*il + n*ir;
|
const uint8_t * q = x[i].qs + 32*il + n*ir;
|
||||||
|
|
||||||
uint8_t sc, m;
|
uint8_t sc, m;
|
||||||
get_scale_min_k4(is + 0, x[i].scales, sc, m);
|
get_scale_min_k4(is + 0, scales_local, sc, m);
|
||||||
const float d1 = dall * sc;
|
const float d1 = dall * sc;
|
||||||
const float m1 = dmin * m;
|
const float m1 = dmin * m;
|
||||||
get_scale_min_k4(is + 1, x[i].scales, sc, m);
|
get_scale_min_k4(is + 1, scales_local, sc, m);
|
||||||
const float d2 = dall * sc;
|
const float d2 = dall * sc;
|
||||||
const float m2 = dmin * m;
|
const float m2 = dmin * m;
|
||||||
for (int l = 0; l < n; ++l) {
|
for (int l = 0; l < n; ++l) {
|
||||||
@ -9894,12 +9898,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
|
|||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
sycl::range<3>(1, 1, 32),
|
sycl::range<3>(1, 1, 32),
|
||||||
sycl::range<3>(1, 1, 32)),
|
sycl::range<3>(1, 1, 32)),
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
dequantize_block_q4_K(vx, y, item_ct1);
|
dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user