mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 12:33:06 +01:00
metal : reorder write loop
This commit is contained in:
parent
5b359bb1e3
commit
535050572a
@ -6410,11 +6410,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
||||
if (sgitg == 0) {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||
device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = 0;
|
||||
for (; i < n_rows/4; i++) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
i *= 4;
|
||||
for (; i < n_rows; i++) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user