diff --git a/ggml-metal.m b/ggml-metal.m index a84b88dda..58d6e567f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1584,11 +1584,12 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; - const int64_t n = ggml_nelements(src1); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_RMS_NORM: { diff --git a/ggml-metal.metal b/ggml-metal.metal index f25e813c2..6026316e5 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3219,69 +3219,89 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg template kernel void kernel_get_rows( device const void * src0, - device const int * src1, + device const char * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, constant uint64_t & nb1, - uint tgpig[[threadgroup_position_in_grid]], + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], - uint tptg [[threads_per_threadgroup]]) { - const int64_t i = tgpig; - const int64_t r = ((device int32_t *) src1)[i]; + uint3 tptg [[threads_per_threadgroup]]) { + //const int64_t i = tgpig; + //const int64_t r = ((device int32_t *) src1)[i]; - for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { float4x4 temp; dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp; + ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; } } kernel void kernel_get_rows_f32( device const void * src0, - device const int * src1, + device const char * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, constant uint64_t & nb1, - uint tgpig[[threadgroup_position_in_grid]], + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], - uint tptg [[threads_per_threadgroup]]) { - const int64_t i = tgpig; - const int64_t r = ((device int32_t *) src1)[i]; + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; - const int64_t i02 = i/ne10; + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - for (int ind = tiitg; ind < ne00; ind += tptg) { - ((device float *) ((device char *) dst + i*nb1))[ind] = + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; } } kernel void kernel_get_rows_f16( device const void * src0, - device const int * src1, + device const char * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, constant uint64_t & nb1, - uint tgpig[[threadgroup_position_in_grid]], + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], - uint tptg [[threads_per_threadgroup]]) { - const int64_t i = tgpig; - const int64_t r = ((device int32_t *) src1)[i]; + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; - const int64_t i02 = i/ne10; + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - for (int ind = tiitg; ind < ne00; ind += tptg) { - ((device float *) ((device char *) dst + i*nb1))[ind] = + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; } } @@ -3543,14 +3563,17 @@ kernel void kernel_mul_mm_id( typedef void (get_rows_t)( device const void * src0, - device const int * src1, + device const char * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, constant uint64_t & nb1, - uint, uint, uint); + constant uint64_t & nb2, + uint3, uint, uint3); //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows;