metal : fix ggml_get_rows to work with non-cont src1

This commit is contained in:
Georgi Gerganov 2023-12-10 09:38:21 +02:00
parent 0710b0f726
commit 016f9bb55a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 54 additions and 30 deletions

View File

@ -1584,11 +1584,12 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7]; [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
const int64_t n = ggml_nelements(src1); [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break; } break;
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
{ {

View File

@ -3219,69 +3219,89 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)> template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows( kernel void kernel_get_rows(
device const void * src0, device const void * src0,
device const int * src1, device const char * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
constant uint64_t & nb02, constant uint64_t & nb02,
constant int64_t & ne10, constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1, constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]], constant uint64_t & nb2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) { uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig; //const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i]; //const int64_t r = ((device int32_t *) src1)[i];
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) { const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
const int64_t i02 = i11;
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
float4x4 temp; float4x4 temp;
dequantize_func( dequantize_func(
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
*(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp; *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
} }
} }
kernel void kernel_get_rows_f32( kernel void kernel_get_rows_f32(
device const void * src0, device const void * src0,
device const int * src1, device const char * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
constant uint64_t & nb02, constant uint64_t & nb02,
constant int64_t & ne10, constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1, constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]], constant uint64_t & nb2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) { uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig; const int64_t i10 = tgpig.x;
const int64_t r = ((device int32_t *) src1)[i]; const int64_t i11 = tgpig.y;
const int64_t i02 = i/ne10; const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
for (int ind = tiitg; ind < ne00; ind += tptg) { const int64_t i02 = i11;
((device float *) ((device char *) dst + i*nb1))[ind] =
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
} }
} }
kernel void kernel_get_rows_f16( kernel void kernel_get_rows_f16(
device const void * src0, device const void * src0,
device const int * src1, device const char * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
constant uint64_t & nb02, constant uint64_t & nb02,
constant int64_t & ne10, constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1, constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]], constant uint64_t & nb2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) { uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig; const int64_t i10 = tgpig.x;
const int64_t r = ((device int32_t *) src1)[i]; const int64_t i11 = tgpig.y;
const int64_t i02 = i/ne10; const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
for (int ind = tiitg; ind < ne00; ind += tptg) { const int64_t i02 = i11;
((device float *) ((device char *) dst + i*nb1))[ind] =
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
} }
} }
@ -3543,14 +3563,17 @@ kernel void kernel_mul_mm_id(
typedef void (get_rows_t)( typedef void (get_rows_t)(
device const void * src0, device const void * src0,
device const int * src1, device const char * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
constant uint64_t & nb02, constant uint64_t & nb02,
constant int64_t & ne10, constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1, constant uint64_t & nb1,
uint, uint, uint); constant uint64_t & nb2,
uint3, uint, uint3);
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>; //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>; //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;