mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
metal : indentations
This commit is contained in:
parent
c60022488a
commit
8f6ad68427
@ -423,8 +423,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|||||||
}
|
}
|
||||||
|
|
||||||
// putting them in the kernel cause a significant performance penalty
|
// putting them in the kernel cause a significant performance penalty
|
||||||
#define N_DST 4 // each SIMD group works on 4 rows
|
#define N_DST 4 // each SIMD group works on 4 rows
|
||||||
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
||||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||||
//Note: This is a template, but strictly speaking it only applies to
|
//Note: This is a template, but strictly speaking it only applies to
|
||||||
// quantizations where the block size is 32. It also does not
|
// quantizations where the block size is 32. It also does not
|
||||||
@ -435,18 +435,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|||||||
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
||||||
uint3 tgpig, uint tiisg, uint sgitg) {
|
uint3 tgpig, uint tiisg, uint sgitg) {
|
||||||
const int nb = ne00/QK4_0;
|
const int nb = ne00/QK4_0;
|
||||||
|
|
||||||
const int r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
|
|
||||||
const int first_row = (r0 * nsg + sgitg) * nr;
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
||||||
|
|
||||||
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
||||||
|
|
||||||
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||||
float yl[16]; // src1 vector cache
|
|
||||||
float sumf[nr]={0.f};
|
|
||||||
|
|
||||||
const int ix = tiisg/2;
|
float yl[16]; // src1 vector cache
|
||||||
const int il = 8*(tiisg%2);
|
float sumf[nr] = {0.f};
|
||||||
|
|
||||||
|
const int ix = (tiisg/2);
|
||||||
|
const int il = (tiisg%2)*8;
|
||||||
|
|
||||||
device const float * yb = y + ix * QK4_0 + il;
|
device const float * yb = y + ix * QK4_0 + il;
|
||||||
|
|
||||||
@ -457,6 +462,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|||||||
sumy += yb[i] + yb[i+1];
|
sumy += yb[i] + yb[i+1];
|
||||||
yl[i+0] = yb[i+ 0];
|
yl[i+0] = yb[i+ 0];
|
||||||
yl[i+1] = yb[i+ 1]/256.f;
|
yl[i+1] = yb[i+ 1]/256.f;
|
||||||
|
|
||||||
sumy += yb[i+16] + yb[i+17];
|
sumy += yb[i+16] + yb[i+17];
|
||||||
yl[i+8] = yb[i+16]/16.f;
|
yl[i+8] = yb[i+16]/16.f;
|
||||||
yl[i+9] = yb[i+17]/4096.f;
|
yl[i+9] = yb[i+17]/4096.f;
|
||||||
@ -472,7 +478,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|||||||
for (int row = 0; row < nr; ++row) {
|
for (int row = 0; row < nr; ++row) {
|
||||||
const float tot = simd_sum(sumf[row]);
|
const float tot = simd_sum(sumf[row]);
|
||||||
if (tiisg == 0 && first_row + row < ne01) {
|
if (tiisg == 0 && first_row + row < ne01) {
|
||||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -490,8 +496,8 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|||||||
constant int64_t & ne1[[buffer(16)]],
|
constant int64_t & ne1[[buffer(16)]],
|
||||||
constant uint & gqa[[buffer(17)]],
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -669,7 +675,7 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
|
Loading…
Reference in New Issue
Block a user