mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
metal : add more general support for ggml_get_rows + tests
This commit is contained in:
parent
9064b1ca05
commit
2cbcba829f
10
ggml-metal.m
10
ggml-metal.m
@ -805,8 +805,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_MUL:
|
||||
@ -828,7 +829,6 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return true;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
return op->ne[0] % 4 == 0;
|
||||
}
|
||||
@ -1573,11 +1573,13 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
|
||||
|
||||
const int64_t n = ggml_nelements(src1);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
|
@ -3223,14 +3223,16 @@ kernel void kernel_get_rows(
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb1,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tptg [[threads_per_threadgroup]]) {
|
||||
const int i = tgpig;
|
||||
const int r = ((device int32_t *) src1)[i];
|
||||
const int64_t i = tgpig;
|
||||
const int64_t r = ((device int32_t *) src1)[i];
|
||||
|
||||
for (int ind = tiitg; ind < ne00/16; ind += tptg) {
|
||||
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
|
||||
float4x4 temp;
|
||||
dequantize_func(
|
||||
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
||||
@ -3238,6 +3240,52 @@ kernel void kernel_get_rows(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_f32(
|
||||
device const void * src0,
|
||||
device const int * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb1,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tptg [[threads_per_threadgroup]]) {
|
||||
const int64_t i = tgpig;
|
||||
const int64_t r = ((device int32_t *) src1)[i];
|
||||
|
||||
const int64_t i02 = i/ne10;
|
||||
|
||||
for (int ind = tiitg; ind < ne00; ind += tptg) {
|
||||
((device float *) ((device char *) dst + i*nb1))[ind] =
|
||||
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_f16(
|
||||
device const void * src0,
|
||||
device const int * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb1,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tptg [[threads_per_threadgroup]]) {
|
||||
const int64_t i = tgpig;
|
||||
const int64_t r = ((device int32_t *) src1)[i];
|
||||
|
||||
const int64_t i02 = i/ne10;
|
||||
|
||||
for (int ind = tiitg; ind < ne00; ind += tptg) {
|
||||
((device float *) ((device char *) dst + i*nb1))[ind] =
|
||||
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||
}
|
||||
}
|
||||
|
||||
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
||||
#define BLOCK_SIZE_K 32
|
||||
@ -3490,11 +3538,13 @@ typedef void (get_rows_t)(
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb1,
|
||||
uint, uint, uint);
|
||||
|
||||
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
||||
|
6
ggml.c
6
ggml.c
@ -10363,7 +10363,7 @@ static void ggml_compute_forward_get_rows_q(
|
||||
|
||||
dequantize_row_q(
|
||||
(const void *) ((char *) src0->data + i02*nb02 + r*nb01),
|
||||
(float *) ((char *) dst->data + i*dst->nb[1]), nc);
|
||||
(float *) ((char *) dst->data + i*nb1), nc);
|
||||
}
|
||||
}
|
||||
|
||||
@ -10396,7 +10396,7 @@ static void ggml_compute_forward_get_rows_f16(
|
||||
|
||||
for (int j = 0; j < nc; ++j) {
|
||||
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
|
||||
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
|
||||
((float *) ((char *) dst->data + i*nb1))[j] = GGML_FP16_TO_FP32(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -10429,7 +10429,7 @@ static void ggml_compute_forward_get_rows_f32(
|
||||
const int64_t i02 = i/ne10;
|
||||
|
||||
ggml_vec_cpy_f32(nc,
|
||||
(float *) ((char *) dst->data + i*dst->nb[1]),
|
||||
(float *) ((char *) dst->data + i*nb1),
|
||||
(float *) ((char *) src0->data + i02*nb02 + r*nb01));
|
||||
}
|
||||
}
|
||||
|
@ -488,17 +488,18 @@ struct test_get_rows : public test_case {
|
||||
const int n; // cols
|
||||
const int m; // rows
|
||||
const int r; // rows to get
|
||||
const int b; // batch size
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, n, m, r);
|
||||
}
|
||||
|
||||
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
|
||||
: type(type), n(n), m(m), r(r) {}
|
||||
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1)
|
||||
: type(type), n(n), m(m), r(r), b(b) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
|
||||
ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
|
||||
ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
|
||||
ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
|
||||
ggml_tensor * out = ggml_get_rows(ctx, in, rows);
|
||||
return out;
|
||||
}
|
||||
@ -507,11 +508,11 @@ struct test_get_rows : public test_case {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
// rows
|
||||
std::vector<int> data(r);
|
||||
for (int i = 0; i < r; i++) {
|
||||
std::vector<int> data(r*b);
|
||||
for (int i = 0; i < r*b; i++) {
|
||||
data[i] = rand() % m;
|
||||
}
|
||||
ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
|
||||
ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
@ -1125,8 +1126,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
}
|
||||
|
||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
|
||||
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
|
||||
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3, 7));
|
||||
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3, 7));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
|
||||
|
Loading…
Reference in New Issue
Block a user