metal : add rope_f16 kernel + optimize cpy kernels

This commit is contained in:
Georgi Gerganov 2023-09-17 23:09:48 +03:00
parent 1fb033fd85
commit fad56936d4
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 62 additions and 19 deletions

View File

@ -100,7 +100,8 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@ -261,7 +262,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@ -335,7 +337,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DEL_KERNEL(rope); GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(alibi_f32); GGML_METAL_DEL_KERNEL(alibi_f32);
GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f32_f32);
@ -870,7 +873,7 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
{ {
const int nth = 32; const int nth = MIN(32, ne00);
if (ne00%4 == 0) { if (ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_soft_max_4]; [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
@ -1134,7 +1137,7 @@ void ggml_metal_graph_compute(
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
const int nth = 512; const int nth = MIN(512, ne00);
[encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1153,7 +1156,7 @@ void ggml_metal_graph_compute(
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
const int nth = 256; const int nth = MIN(256, ne00);
[encoder setComputePipelineState:ctx->pipeline_norm]; [encoder setComputePipelineState:ctx->pipeline_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1171,6 +1174,8 @@ void ggml_metal_graph_compute(
{ {
GGML_ASSERT((src0t == GGML_TYPE_F32)); GGML_ASSERT((src0t == GGML_TYPE_F32));
const int nth = MIN(1024, ne00);
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
const int n_head = ((int32_t *) dst->op_params)[1]; const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias; float max_bias;
@ -1204,15 +1209,15 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&m0 length:sizeof( float) atIndex:18]; [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
const int nth = 32;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
{ {
GGML_ASSERT(ne10 == ne02); GGML_ASSERT(ne10 == ne02);
//const int n_past = ((int32_t *) dst->op_params)[0]; const int nth = MIN(1024, ne00);
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
@ -1221,7 +1226,12 @@ void ggml_metal_graph_compute(
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
[encoder setComputePipelineState:ctx->pipeline_rope]; switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
default: GGML_ASSERT(false);
};
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1241,19 +1251,19 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
//[encoder setBytes:&n_past length:sizeof( int) atIndex:19]; [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
[encoder setBytes:&mode length:sizeof( int) atIndex:21]; [encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22]; [encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_DUP: case GGML_OP_DUP:
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_CONT: case GGML_OP_CONT:
{ {
const int nth = 32; const int nth = MIN(1024, ne00);
switch (src0t) { switch (src0t) {
case GGML_TYPE_F32: case GGML_TYPE_F32:

View File

@ -853,6 +853,36 @@ kernel void kernel_alibi_f32(
} }
} }
typedef void (rope_t)(
device const void * src0,
device const int32_t * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]);
template<typename T>
kernel void kernel_rope( kernel void kernel_rope(
device const void * src0, device const void * src0,
device const int32_t * src1, device const int32_t * src1,
@ -901,11 +931,11 @@ kernel void kernel_rope(
const float cos_theta = cos(theta); const float cos_theta = cos(theta);
const float sin_theta = sin(theta); const float sin_theta = sin(theta);
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0]; const T x0 = src[0];
const float x1 = src[1]; const T x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta; dst_data[1] = x0*sin_theta + x1*cos_theta;
@ -920,8 +950,8 @@ kernel void kernel_rope(
const int64_t i0 = ib*n_dims + ic/2; const int64_t i0 = ib*n_dims + ic/2;
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0]; const float x0 = src[0];
const float x1 = src[n_dims/2]; const float x1 = src[n_dims/2];
@ -933,6 +963,9 @@ kernel void kernel_rope(
} }
} }
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
kernel void kernel_cpy_f16_f16( kernel void kernel_cpy_f16_f16(
device const half * src0, device const half * src0,
device half * dst, device half * dst,