mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-04 01:57:53 +01:00
metal : add F32 -> Q8_0 copy kernel
This commit is contained in:
parent
d04ee928a2
commit
bcfebf241d
14
ggml-metal.m
14
ggml-metal.m
@ -118,6 +118,7 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_DECL_KERNEL(concat);
|
||||
GGML_METAL_DECL_KERNEL(sqr);
|
||||
@ -324,6 +325,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(im2col_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_ADD_KERNEL(concat);
|
||||
GGML_METAL_ADD_KERNEL(sqr);
|
||||
@ -425,6 +427,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_DEL_KERNEL(concat);
|
||||
GGML_METAL_DEL_KERNEL(sqr);
|
||||
@ -1549,14 +1552,19 @@ void ggml_metal_graph_compute(
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
{
|
||||
const int nth = MIN(1024, ne00);
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
||||
|
||||
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
||||
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
||||
|
||||
switch (dstt) {
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
||||
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
};
|
||||
} break;
|
||||
|
@ -1460,6 +1460,64 @@ kernel void kernel_cpy_f32_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_q8_0(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
|
||||
|
||||
device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
const float v = src[j];
|
||||
amax = MAX(amax, fabs(v));
|
||||
}
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst_data[i00/QK8_0].d = d;
|
||||
|
||||
for (int j = 0; j < QK8_0; ++j) {
|
||||
const float x0 = src[j]*id;
|
||||
|
||||
dst_data[i00/QK8_0].qs[j] = round(x0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_concat(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
|
@ -8738,7 +8738,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
//const ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
// TODO: move as params
|
||||
const ggml_type k_type = GGML_TYPE_Q4_0;
|
||||
const ggml_type k_type = GGML_TYPE_Q8_0;
|
||||
const ggml_type v_type = GGML_TYPE_F16;
|
||||
|
||||
GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(k_type) == 0);
|
||||
|
Loading…
Reference in New Issue
Block a user