mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 21:10:24 +01:00
~4-5% improvement for Q8_0 TG on metal
This commit is contained in:
parent
363f0bf558
commit
6af0bab347
@ -443,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|||||||
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define NB_Q8_0 8
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q8_0_f32(
|
kernel void kernel_mul_mat_q8_0_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
@ -471,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|||||||
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||||
|
|
||||||
float yl[16];
|
float yl[NB_Q8_0];
|
||||||
float sumf[nr]={0.f};
|
float sumf[nr]={0.f};
|
||||||
|
|
||||||
const int ix = tiisg/2;
|
const int ix = tiisg/4;
|
||||||
const int il = tiisg%2;
|
const int il = tiisg%4;
|
||||||
|
|
||||||
device const float * yb = y + ix * QK8_0 + 16*il;
|
device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
|
||||||
|
|
||||||
// each thread in a SIMD group deals with half a block.
|
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
|
||||||
for (int ib = ix; ib < nb; ib += nw/2) {
|
for (int ib = ix; ib < nb; ib += nw/4) {
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < NB_Q8_0; ++i) {
|
||||||
yl[i] = yb[i];
|
yl[i] = yb[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < nr; row++) {
|
for (int row = 0; row < nr; row++) {
|
||||||
device const int8_t * qs = x[ib+row*nb].qs + 16*il;
|
device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
|
||||||
float sumq = 0.f;
|
float sumq = 0.f;
|
||||||
for (int iq = 0; iq < 16; ++iq) {
|
for (int iq = 0; iq < NB_Q8_0; ++iq) {
|
||||||
sumq += qs[iq] * yl[iq];
|
sumq += qs[iq] * yl[iq];
|
||||||
}
|
}
|
||||||
sumf[row] += sumq*x[ib+row*nb].d;
|
sumf[row] += sumq*x[ib+row*nb].d;
|
||||||
}
|
}
|
||||||
|
|
||||||
yb += QK8_0 * 16;
|
yb += NB_Q8_0 * nw;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < nr; ++row) {
|
for (int row = 0; row < nr; ++row) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user