mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-03 15:23:02 +01:00
Quite significant PP speedup on metal
This commit is contained in:
parent
e3ff8c20c8
commit
2b601702a8
11
ggml-metal.m
11
ggml-metal.m
@ -906,8 +906,8 @@ void ggml_metal_graph_compute(
|
|||||||
GGML_ASSERT(ne02 == 1);
|
GGML_ASSERT(ne02 == 1);
|
||||||
GGML_ASSERT(ne12 == 1);
|
GGML_ASSERT(ne12 == 1);
|
||||||
|
|
||||||
nth0 = 2;
|
nth0 = 4; //1;
|
||||||
nth1 = 32;
|
nth1 = 8; //32;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
@ -955,9 +955,12 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
|
else if (src0t == GGML_TYPE_Q4_K) {
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
}
|
||||||
else if (src0t == GGML_TYPE_Q3_K) {
|
else if (src0t == GGML_TYPE_Q3_K) {
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
@ -972,7 +975,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
//[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
//[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne11 + 3)/4, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
@ -505,6 +505,8 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define N_F16_F32 4
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f16_f32(
|
kernel void kernel_mul_mat_f16_f32(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
@ -527,20 +529,28 @@ kernel void kernel_mul_mat_f16_f32(
|
|||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t rb = N_F16_F32*tgpig.y;
|
||||||
const int64_t im = tgpig.z;
|
const int64_t im = tgpig.z;
|
||||||
|
|
||||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
||||||
|
|
||||||
float sumf = 0;
|
for (int row = 0; row < N_F16_F32; ++row) {
|
||||||
for (int i = tiisg; i < ne00; i += 32) {
|
int r1 = rb + row;
|
||||||
sumf += (float) x[i] * (float) y[i];
|
if (r1 >= ne11) {
|
||||||
}
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
float all_sum = simd_sum(sumf);
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
if (tiisg == 0) {
|
|
||||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
float sumf = 0;
|
||||||
|
for (int i = tiisg; i < ne00; i += 32) {
|
||||||
|
sumf += (float) x[i] * (float) y[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -1241,7 +1251,8 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|||||||
const int r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int r2 = tgpig.z;
|
const int r2 = tgpig.z;
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||||
|
const int first_row = r0 * N_DST;
|
||||||
const int ib_row = first_row * nb;
|
const int ib_row = first_row * nb;
|
||||||
const uint offset0 = r2/gqa*(nb*ne0);
|
const uint offset0 = r2/gqa*(nb*ne0);
|
||||||
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
||||||
|
Loading…
Reference in New Issue
Block a user