metal : revert 6af0bab until we fix it

This restores the generated text to be the same as before #2959
This commit is contained in:
Georgi Gerganov 2023-09-03 12:40:56 +03:00
parent afc43d5f82
commit d9151e6f57
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -536,27 +536,14 @@ kernel void kernel_mul_mat_f16_f32_1row(
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0; float sumf = 0;
if (ne00 < 128) { for (int i = tiisg; i < ne00; i += 32) {
for (int i = tiisg; i < ne00; i += 32) { sumf += (float) x[i] * (float) y[i];
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
} else {
device const half4 * x4 = (device const half4 *) x;
device const float4 * y4 = (device const float4 *) y;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
} }
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
} }
#define N_F16_F32 4 #define N_F16_F32 4
@ -588,49 +575,24 @@ kernel void kernel_mul_mat_f16_f32(
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
if (ne00 < 128) { for (int row = 0; row < N_F16_F32; ++row) {
for (int row = 0; row < N_F16_F32; ++row) { int r1 = rb + row;
int r1 = rb + row; if (r1 >= ne11) {
if (r1 >= ne11) { break;
break;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
} }
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y;
float sumf = 0; float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) { for (int i = tiisg; i < ne00; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; sumf += (float) x[i] * (float) y[i];
} }
float all_sum = simd_sum(sumf); float all_sum = simd_sum(sumf);
if (tiisg == 0) { if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i]; dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
} }
} }
} }
kernel void kernel_alibi_f32( kernel void kernel_alibi_f32(