ggml : fix ARM build + speed-up ggml_mul

This commit is contained in:
Georgi Gerganov 2023-07-28 16:31:59 +03:00
parent a4d1eb72c6
commit e5d23f2e7e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 13 additions and 14 deletions

21
ggml.c
View File

@@ -2681,7 +2681,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
 const block_q8_1 * restrict y0 = &y[i + 0];
 const block_q8_1 * restrict y1 = &y[i + 1];
-summs += Q4_1M(x0->dm) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
+summs += Q4_1M(x0->dm) * y0->s + Q4_1M(x1->dm) * y1->s;
 const uint8x16_t m4b = vdupq_n_u8(0x0F);
@@ -8898,6 +8898,13 @@ static void ggml_compute_forward_mul_f32(
 const int64_t nr = ggml_nrows(src0);
+// rows per thread
+const int dr = (nr + nth - 1)/nth;
+// row range for this thread
+const int ir0 = dr*ith;
+const int ir1 = MIN(ir0 + dr, nr);
 GGML_TENSOR_BINARY_OP_LOCALS;
 GGML_ASSERT( nb0 == sizeof(float));
@@ -8905,7 +8912,7 @@ static void ggml_compute_forward_mul_f32(
 GGML_ASSERT(ne00 == ne10);
 if (nb10 == sizeof(float)) {
-for (int64_t ir = ith; ir < nr; ir += nth) {
+for (int64_t ir = ir0; ir < ir1; ++ir) {
 // src0 and dst are same shape => same indices
 const int64_t i03 = ir/(ne02*ne01);
 const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -8919,19 +8926,11 @@ static void ggml_compute_forward_mul_f32(
 float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-#ifdef GGML_USE_ACCELERATE
-UNUSED(ggml_vec_mul_f32);
-vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
-#else
 ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
-#endif
-// }
-// }
 }
 } else {
 // src1 is not contiguous
-for (int64_t ir = ith; ir < nr; ir += nth) {
+for (int64_t ir = ir0; ir < ir1; ++ir) {
 // src0 and dst are same shape => same indices
 // src1 is broadcastable across src0 and dst in i1, i2, i3
 const int64_t i03 = ir/(ne02*ne01);

6
ggml.h
View File

@@ -281,9 +281,9 @@ extern "C" {
 GGML_TYPE_Q5_K = 13,
 GGML_TYPE_Q6_K = 14,
 GGML_TYPE_Q8_K = 15,
-GGML_TYPE_I8 = 16,
-GGML_TYPE_I16 = 17,
-GGML_TYPE_I32 = 18,
+GGML_TYPE_I8,
+GGML_TYPE_I16,
+GGML_TYPE_I32,
 GGML_TYPE_COUNT,
 };