mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 03:12:23 +01:00
metal : implement q5_0 and q5_1 kernels (#3648)
* metal : implement dequantize_q5_0 * metal : block_q_n_dot_y for block_q5_0 (broken) * metal : revert unnecessary change * metal : implement dequantize_q5_1 * metal : block_q_n_dot_y for q5_1 (broken) * metal : fix block_q_n_dot_y * minor : spaces / formatting --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
1117d06607
commit
c67fe68e41
47
ggml-metal.m
47
ggml-metal.m
@ -73,6 +73,8 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||||
|
GGML_METAL_DECL_KERNEL(get_rows_q5_0);
|
||||||
|
GGML_METAL_DECL_KERNEL(get_rows_q5_1);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
||||||
@ -87,6 +89,8 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
|
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
|
||||||
@ -97,6 +101,8 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
||||||
@ -254,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||||
|
GGML_METAL_ADD_KERNEL(get_rows_q5_0);
|
||||||
|
GGML_METAL_ADD_KERNEL(get_rows_q5_1);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
||||||
@ -268,6 +276,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
|
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
|
||||||
@ -278,8 +288,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
||||||
@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
||||||
|
GGML_METAL_DEL_KERNEL(get_rows_q5_0);
|
||||||
|
GGML_METAL_DEL_KERNEL(get_rows_q5_1);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
|
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
|
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
|
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
|
||||||
@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
|
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
|
||||||
@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
||||||
@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute(
|
|||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
||||||
|
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
|
||||||
|
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
|
||||||
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
|
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
|
||||||
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
||||||
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
||||||
@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute(
|
|||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ne02 == 1);
|
||||||
|
GGML_ASSERT(ne12 == 1);
|
||||||
|
|
||||||
|
nth0 = 8;
|
||||||
|
nth1 = 8;
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ne02 == 1);
|
||||||
|
GGML_ASSERT(ne12 == 1);
|
||||||
|
|
||||||
|
nth0 = 8;
|
||||||
|
nth1 = 8;
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
|
||||||
|
} break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
GGML_ASSERT(ne02 == 1);
|
||||||
@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
||||||
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||||
|
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute(
|
|||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
||||||
|
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
|
||||||
|
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
|
||||||
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
||||||
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
||||||
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
||||||
|
163
ggml-metal.metal
163
ggml-metal.metal
@ -18,6 +18,21 @@ typedef struct {
|
|||||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||||
} block_q4_1;
|
} block_q4_1;
|
||||||
|
|
||||||
|
#define QK5_0 32
|
||||||
|
typedef struct {
|
||||||
|
half d; // delta
|
||||||
|
uint8_t qh[4]; // 5-th bit of quants
|
||||||
|
uint8_t qs[QK5_0 / 2]; // nibbles / quants
|
||||||
|
} block_q5_0;
|
||||||
|
|
||||||
|
#define QK5_1 32
|
||||||
|
typedef struct {
|
||||||
|
half d; // delta
|
||||||
|
half m; // min
|
||||||
|
uint8_t qh[4]; // 5-th bit of quants
|
||||||
|
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
||||||
|
} block_q5_1;
|
||||||
|
|
||||||
#define QK8_0 32
|
#define QK8_0 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
@ -399,8 +414,11 @@ kernel void kernel_rms_norm(
|
|||||||
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
||||||
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
float d = qb_curr->d;
|
float d = qb_curr->d;
|
||||||
|
|
||||||
float2 acc = 0.f;
|
float2 acc = 0.f;
|
||||||
|
|
||||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
||||||
|
|
||||||
for (int i = 0; i < 8; i+=2) {
|
for (int i = 0; i < 8; i+=2) {
|
||||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
||||||
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||||
@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
|
|||||||
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
float d = qb_curr->d;
|
float d = qb_curr->d;
|
||||||
float m = qb_curr->m;
|
float m = qb_curr->m;
|
||||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
|
||||||
float2 acc = 0.f;
|
float2 acc = 0.f;
|
||||||
|
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
||||||
|
|
||||||
for (int i = 0; i < 8; i+=2) {
|
for (int i = 0; i < 8; i+=2) {
|
||||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
||||||
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||||
@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|||||||
return d * (acc[0] + acc[1]) + sumy * m;
|
return d * (acc[0] + acc[1]) + sumy * m;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
||||||
|
// il indicates where the q5 quants begin (0 or QK5_0/4)
|
||||||
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||||
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
||||||
|
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
|
float d = qb_curr->d;
|
||||||
|
|
||||||
|
float2 acc = 0.f;
|
||||||
|
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
||||||
|
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i+=2) {
|
||||||
|
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
||||||
|
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
||||||
|
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
||||||
|
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
||||||
|
}
|
||||||
|
return d * (sumy * -16.f + acc[0] + acc[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
|
||||||
|
// il indicates where the q5 quants begin (0 or QK5_1/4)
|
||||||
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||||
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
||||||
|
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
|
float d = qb_curr->d;
|
||||||
|
float m = qb_curr->m;
|
||||||
|
|
||||||
|
float2 acc = 0.f;
|
||||||
|
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
|
||||||
|
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i+=2) {
|
||||||
|
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
||||||
|
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
||||||
|
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
||||||
|
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
||||||
|
}
|
||||||
|
return d * (acc[0] + acc[1]) + sumy * m;
|
||||||
|
}
|
||||||
|
|
||||||
// putting them in the kernel cause a significant performance penalty
|
// putting them in the kernel cause a significant performance penalty
|
||||||
#define N_DST 4 // each SIMD group works on 4 rows
|
#define N_DST 4 // each SIMD group works on 4 rows
|
||||||
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
||||||
@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|||||||
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_mul_mv_q5_0_f32(
|
||||||
|
device const void * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01[[buffer(4)]],
|
||||||
|
constant int64_t & ne02[[buffer(5)]],
|
||||||
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant int64_t & ne1[[buffer(16)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_mul_mv_q5_1_f32(
|
||||||
|
device const void * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01[[buffer(4)]],
|
||||||
|
constant int64_t & ne02[[buffer(5)]],
|
||||||
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant int64_t & ne1[[buffer(16)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#define NB_Q8_0 8
|
#define NB_Q8_0 8
|
||||||
|
|
||||||
kernel void kernel_mul_mv_q8_0_f32(
|
kernel void kernel_mul_mv_q8_0_f32(
|
||||||
@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4x4>
|
||||||
|
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
||||||
|
const float d = xb->d;
|
||||||
|
const float md = -16.h * xb->d;
|
||||||
|
const ushort mask = il ? 0x00F0 : 0x000F;
|
||||||
|
|
||||||
|
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||||
|
|
||||||
|
const int x_mv = il ? 4 : 0;
|
||||||
|
|
||||||
|
const int gh_mv = il ? 12 : 0;
|
||||||
|
const int gh_bk = il ? 0 : 4;
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
// extract the 5-th bits for x0 and x1
|
||||||
|
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||||
|
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||||
|
|
||||||
|
// combine the 4-bits from qs with the 5th bit
|
||||||
|
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||||
|
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||||
|
|
||||||
|
reg[i/2][2*(i%2)+0] = d * x0 + md;
|
||||||
|
reg[i/2][2*(i%2)+1] = d * x1 + md;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename type4x4>
|
||||||
|
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
||||||
|
const float d = xb->d;
|
||||||
|
const float m = xb->m;
|
||||||
|
const ushort mask = il ? 0x00F0 : 0x000F;
|
||||||
|
|
||||||
|
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||||
|
|
||||||
|
const int x_mv = il ? 4 : 0;
|
||||||
|
|
||||||
|
const int gh_mv = il ? 12 : 0;
|
||||||
|
const int gh_bk = il ? 0 : 4;
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
// extract the 5-th bits for x0 and x1
|
||||||
|
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||||
|
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||||
|
|
||||||
|
// combine the 4-bits from qs with the 5th bit
|
||||||
|
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||||
|
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||||
|
|
||||||
|
reg[i/2][2*(i%2)+0] = d * x0 + m;
|
||||||
|
reg[i/2][2*(i%2)+1] = d * x1 + m;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||||
@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
|
|||||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
||||||
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
||||||
|
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
|
||||||
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
||||||
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||||
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||||
@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
|
|||||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
||||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
||||||
|
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
||||||
|
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
||||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
||||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||||
|
Loading…
Reference in New Issue
Block a user