mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 22:08:46 +01:00
metal : add Q4_1 implementation (#1785)
23.3 ms / token, so just ~1% slower than q4_0. Achieves 290 GB/s memory throughput. Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
4f0154b0ba
commit
e9b66ee982
16
ggml-metal.m
16
ggml-metal.m
@ -50,12 +50,14 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||||
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
|
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
|
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
|
||||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
|
||||||
@ -141,12 +143,14 @@ struct ggml_metal_context * ggml_metal_init(void) {
|
|||||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||||
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
|
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
|
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
|
||||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
|
||||||
@ -545,6 +549,15 @@ void ggml_metal_graph_compute(
|
|||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ne02 == 1);
|
||||||
|
GGML_ASSERT(ne12 == 1);
|
||||||
|
|
||||||
|
nth0 = 8;
|
||||||
|
nth1 = 8;
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
|
||||||
|
} break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
GGML_ASSERT(ne02 == 1);
|
||||||
@ -596,7 +609,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0) {
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
|
||||||
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else if (src0t == GGML_TYPE_Q2_K) {
|
} else if (src0t == GGML_TYPE_Q2_K) {
|
||||||
@ -623,6 +636,7 @@ void ggml_metal_graph_compute(
|
|||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||||
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
||||||
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
|
||||||
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
|
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
|
||||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
|
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
|
||||||
|
123
ggml-metal.metal
123
ggml-metal.metal
@ -11,6 +11,13 @@ typedef struct {
|
|||||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||||
} block_q4_0;
|
} block_q4_0;
|
||||||
|
|
||||||
|
#define QK4_1 32
|
||||||
|
typedef struct {
|
||||||
|
half d; // delta
|
||||||
|
half m; // min
|
||||||
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||||
|
} block_q4_1;
|
||||||
|
|
||||||
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
|
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
|
||||||
const int qk = QK4_0;
|
const int qk = QK4_0;
|
||||||
|
|
||||||
@ -31,6 +38,27 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
|
||||||
|
const int qk = QK4_1;
|
||||||
|
|
||||||
|
assert(k % qk == 0);
|
||||||
|
|
||||||
|
const int nb = k / qk;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; i++) {
|
||||||
|
const half d = x[i].d;
|
||||||
|
const half m = x[i].m;
|
||||||
|
|
||||||
|
for (int j = 0; j < qk/2; ++j) {
|
||||||
|
const int x0 = (x[i].qs[j] & 0x0F);
|
||||||
|
const int x1 = (x[i].qs[j] >> 4);
|
||||||
|
|
||||||
|
y[i*qk + j + 0 ] = x0*d + m;
|
||||||
|
y[i*qk + j + qk/2] = x1*d + m;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_add(
|
kernel void kernel_add(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
@ -212,6 +240,22 @@ kernel void kernel_get_rows_q4_0(
|
|||||||
(device float *) ((device char *) dst + i*nb1), ne00);
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_get_rows_q4_1(
|
||||||
|
device const void * src0,
|
||||||
|
device const int * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
const int i = tpig;
|
||||||
|
const int r = ((device int32_t *) src1)[i];
|
||||||
|
|
||||||
|
dequantize_row_q4_1(
|
||||||
|
(device const block_q4_1 *) ((device char *) src0 + r*nb01),
|
||||||
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_rms_norm(
|
kernel void kernel_rms_norm(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -350,6 +394,85 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||||||
//}
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_mul_mat_q4_1_f32(
|
||||||
|
device const void * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne11,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
threadgroup float * sum [[threadgroup(0)]],
|
||||||
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint2 tpig[[thread_position_in_grid]],
|
||||||
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint2 tptg[[threads_per_threadgroup]]) {
|
||||||
|
const int nb = ne00/QK4_1;
|
||||||
|
|
||||||
|
const int64_t r0 = tgpig.x;
|
||||||
|
const int64_t r1 = tgpig.y;
|
||||||
|
|
||||||
|
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
|
||||||
|
device const float * y = (device const float *) src1 + r1*ne10;
|
||||||
|
|
||||||
|
const uint nth = tptg.x*tptg.y;
|
||||||
|
const uint ith = tptg.y*tpitg.x + tpitg.y;
|
||||||
|
|
||||||
|
const int ix = tpitg.y/4; // 0 or 1
|
||||||
|
const int iy = tpitg.y - 4*ix; // 0...3
|
||||||
|
|
||||||
|
const int first = 4 * iy;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
|
||||||
|
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
|
||||||
|
|
||||||
|
const float d = (float)x[i].d;
|
||||||
|
const float m = (float)x[i].m;
|
||||||
|
|
||||||
|
device const uint8_t * xl = x[i].qs + first;
|
||||||
|
device const float * yl = y + i * QK4_1 + first;
|
||||||
|
|
||||||
|
float2 acc = {0.0f, 0.0f};
|
||||||
|
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
|
||||||
|
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
|
||||||
|
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
sumf += acc[0] + acc[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
sum[ith] = sumf;
|
||||||
|
|
||||||
|
//
|
||||||
|
// Accumulate the sum from all threads in the threadgroup
|
||||||
|
//
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (ith%4 == 0) {
|
||||||
|
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (ith%16 == 0) {
|
||||||
|
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (ith == 0) {
|
||||||
|
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
||||||
|
dst[r1*ne0 + r0] = sum[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f16_f32(
|
kernel void kernel_mul_mat_f16_f32(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
|
Loading…
Reference in New Issue
Block a user