metal : add cpy f16 -> f32 kernel

This commit is contained in:
Georgi Gerganov 2023-12-12 14:14:15 +02:00
parent a742d9f9b7
commit 08eb99179a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 84 additions and 15 deletions

View File

@ -155,6 +155,7 @@ struct ggml_metal_context {
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
GGML_METAL_DECL_KERNEL(cpy_f16_f32);
GGML_METAL_DECL_KERNEL(concat);
GGML_METAL_DECL_KERNEL(sqr);
GGML_METAL_DECL_KERNEL(sum_rows);
@ -424,6 +425,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
GGML_METAL_ADD_KERNEL(cpy_f16_f32);
GGML_METAL_ADD_KERNEL(concat);
GGML_METAL_ADD_KERNEL(sqr);
GGML_METAL_ADD_KERNEL(sum_rows);
@ -539,6 +541,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
GGML_METAL_DEL_KERNEL(cpy_f16_f32);
GGML_METAL_DEL_KERNEL(concat);
GGML_METAL_DEL_KERNEL(sqr);
GGML_METAL_DEL_KERNEL(sum_rows);
@ -867,12 +870,37 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
case GGML_OP_ARGSORT:
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return true;
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:
{
switch (op->src[0]->type) {
case GGML_TYPE_F32:
switch (op->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return true;
default:
return false;
}
case GGML_TYPE_F16:
switch (op->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
default:
return false;
}
default:
return false;
};
}
case GGML_OP_DIAG_MASK_INF:
{
return op->ne[0] % 4 == 0;
@ -2021,7 +2049,7 @@ void ggml_metal_graph_compute(
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;

View File

@ -1738,6 +1738,47 @@ kernel void kernel_cpy_f16_f16(
}
}
kernel void kernel_cpy_f16_f32(
device const half * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,