metal : add more general support for ggml_get_rows + tests

This commit is contained in:
Georgi Gerganov 2023-12-09 14:18:42 +02:00
parent 9064b1ca05
commit 2cbcba829f
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 78 additions and 25 deletions

View File

@ -805,8 +805,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_NONE: case GGML_OP_NONE:
case GGML_OP_RESHAPE: case GGML_OP_RESHAPE:
case GGML_OP_VIEW: case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE: case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS:
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
case GGML_OP_ADD: case GGML_OP_ADD:
case GGML_OP_MUL: case GGML_OP_MUL:
@ -828,7 +829,6 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
return true; return true;
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{ {
return op->ne[0] % 4 == 0; return op->ne[0] % 4 == 0;
} }
@ -1568,16 +1568,18 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented"); default: GGML_ASSERT(false && "not implemented");
} }
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
const int64_t n = ggml_nelements(src1); const int64_t n = ggml_nelements(src1);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break; } break;
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
{ {

View File

@ -3223,14 +3223,16 @@ kernel void kernel_get_rows(
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1, constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]], uint tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint tptg[[threads_per_threadgroup]]) { uint tptg [[threads_per_threadgroup]]) {
const int i = tgpig; const int64_t i = tgpig;
const int r = ((device int32_t *) src1)[i]; const int64_t r = ((device int32_t *) src1)[i];
for (int ind = tiitg; ind < ne00/16; ind += tptg) { for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
float4x4 temp; float4x4 temp;
dequantize_func( dequantize_func(
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
@ -3238,6 +3240,52 @@ kernel void kernel_get_rows(
} }
} }
kernel void kernel_get_rows_f32(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i];
const int64_t i02 = i/ne10;
for (int ind = tiitg; ind < ne00; ind += tptg) {
((device float *) ((device char *) dst + i*nb1))[ind] =
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}
kernel void kernel_get_rows_f16(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i];
const int64_t i02 = i/ne10;
for (int ind = tiitg; ind < ne00; ind += tptg) {
((device float *) ((device char *) dst + i*nb1))[ind] =
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32 #define BLOCK_SIZE_K 32
@ -3490,11 +3538,13 @@ typedef void (get_rows_t)(
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1, constant uint64_t & nb1,
uint, uint, uint); uint, uint, uint);
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>; //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>; //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;

6
ggml.c
View File

@ -10363,7 +10363,7 @@ static void ggml_compute_forward_get_rows_q(
dequantize_row_q( dequantize_row_q(
(const void *) ((char *) src0->data + i02*nb02 + r*nb01), (const void *) ((char *) src0->data + i02*nb02 + r*nb01),
(float *) ((char *) dst->data + i*dst->nb[1]), nc); (float *) ((char *) dst->data + i*nb1), nc);
} }
} }
@ -10396,7 +10396,7 @@ static void ggml_compute_forward_get_rows_f16(
for (int j = 0; j < nc; ++j) { for (int j = 0; j < nc; ++j) {
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j]; ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); ((float *) ((char *) dst->data + i*nb1))[j] = GGML_FP16_TO_FP32(v);
} }
} }
} }
@ -10429,7 +10429,7 @@ static void ggml_compute_forward_get_rows_f32(
const int64_t i02 = i/ne10; const int64_t i02 = i/ne10;
ggml_vec_cpy_f32(nc, ggml_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i*dst->nb[1]), (float *) ((char *) dst->data + i*nb1),
(float *) ((char *) src0->data + i02*nb02 + r*nb01)); (float *) ((char *) src0->data + i02*nb02 + r*nb01));
} }
} }

View File

@ -488,17 +488,18 @@ struct test_get_rows : public test_case {
const int n; // cols const int n; // cols
const int m; // rows const int m; // rows
const int r; // rows to get const int r; // rows to get
const int b; // batch size
std::string vars() override { std::string vars() override {
return VARS_TO_STR4(type, n, m, r); return VARS_TO_STR4(type, n, m, r);
} }
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3) test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1)
: type(type), n(n), m(m), r(r) {} : type(type), n(n), m(m), r(r), b(b) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m); ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r); ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
ggml_tensor * out = ggml_get_rows(ctx, in, rows); ggml_tensor * out = ggml_get_rows(ctx, in, rows);
return out; return out;
} }
@ -507,11 +508,11 @@ struct test_get_rows : public test_case {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) { if (t->type == GGML_TYPE_I32) {
// rows // rows
std::vector<int> data(r); std::vector<int> data(r*b);
for (int i = 0; i < r; i++) { for (int i = 0; i < r*b; i++) {
data[i] = rand() % m; data[i] = rand() % m;
} }
ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int)); ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
} else { } else {
init_tensor_uniform(t); init_tensor_uniform(t);
} }
@ -1125,8 +1126,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
} }
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3)); test_cases.emplace_back(new test_get_rows(type, 10, 5, 3, 7));
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3)); test_cases.emplace_back(new test_get_rows(type, 16, 5, 3, 7));
} }
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));