mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
iq3_s: minor improvement on Metal
49.4 t/s -> 50.3 t/s
This commit is contained in:
parent
9c5b594cde
commit
7b629c3b65
@ -4773,8 +4773,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|||||||
|
|
||||||
float2 sum = {0};
|
float2 sum = {0};
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
|
||||||
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
|
||||||
|
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
|
||||||
|
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
|
||||||
for (int j = 0; j < 4; ++j) {
|
for (int j = 0; j < 4; ++j) {
|
||||||
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
||||||
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
||||||
|
Loading…
Reference in New Issue
Block a user