mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
metal : add missing barriers for mul-mat (#2699)
This commit is contained in:
parent
226255b44e
commit
14b1d7e6f7
@ -1850,6 +1850,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
//load data and store to threadgroup memory
|
//load data and store to threadgroup memory
|
||||||
half4x4 temp_a;
|
half4x4 temp_a;
|
||||||
dequantize_func(x, il, temp_a);
|
dequantize_func(x, il, temp_a);
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
#pragma unroll(16)
|
#pragma unroll(16)
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i++) {
|
||||||
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
||||||
@ -1895,14 +1896,14 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
||||||
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
threadgroup_barrier(mem_flags::mem_device);
|
|
||||||
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_device);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
||||||
if (sgitg==0) {
|
if (sgitg==0) {
|
||||||
for (int i = 0; i < n_rows; i++) {
|
for (int i = 0; i < n_rows; i++) {
|
||||||
|
Loading…
Reference in New Issue
Block a user