wipwipwiwpip

This commit is contained in:
Georgi Gerganov 2024-05-27 12:04:09 +03:00
parent fc59407efe
commit ddc59e8e0a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 132 additions and 1 deletions

View File

@ -187,6 +187,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@ -771,6 +772,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
return true; return true;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
return ctx->support_simdgroup_reduction && return ctx->support_simdgroup_reduction &&
@ -968,6 +971,10 @@ static enum ggml_status ggml_metal_graph_compute(
// GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
// ggml_is_contiguous(src1), src1->name); // ggml_is_contiguous(src1), src1->name);
//} //}
//if (src2) {
// GGML_METAL_LOG_INFO("%s: src2 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne20, ne21, ne22,
// ggml_is_contiguous(src2), src2->name);
//}
//if (dst) { //if (dst) {
// GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
// dst->name); // dst->name);
@ -2688,6 +2695,55 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} }
} break; } break;
case GGML_OP_SSM_CONV:
{
id<MTLComputePipelineState> pipeline = nil;
//pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
//[encoder setComputePipelineState:pipeline];
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
//[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
//[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
//[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
//[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
//[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
//[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
//[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
//[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
//[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
//[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
//[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
//[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
//[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
//[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
//[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
//[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
//[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
//[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
//[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
//[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
//[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
//[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
//[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
//[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
//[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
//[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
//[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
//[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
//if (bcast_row) {
// const int64_t n = ggml_nelements(dst)/4;
// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
//} else {
// const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
//}
} break;
case GGML_OP_DUP: case GGML_OP_DUP:
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_CONT: case GGML_OP_CONT:

View File

@ -2698,6 +2698,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
kernel void kernel_ssm_conv_f32(
device const float * src0,
device const float * src1,
device const float * src2,
device const int32_t * src3,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne11,
constant int64_t & ne20,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb21,
constant uint64_t & nb22,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
}
kernel void kernel_cpy_f16_f16( kernel void kernel_cpy_f16_f16(
device const half * src0, device const half * src0,
device half * dst, device half * dst,

View File

@ -1561,6 +1561,56 @@ struct test_flash_attn_ext : public test_case {
} }
}; };
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type_s;
const ggml_type type_x;
const ggml_type type_c;
const ggml_type type_sq;
const int64_t d_inner;
const int64_t d_conv;
const int64_t n_tokens;
const int64_t n_rs;
std::string vars() override {
return VARS_TO_STR8(type_s, type_x, type_c, type_sq, d_inner, d_conv, n_tokens, n_rs);
}
test_ssm_conv(ggml_type type_s = GGML_TYPE_F32,
ggml_type type_x = GGML_TYPE_F32,
ggml_type type_c = GGML_TYPE_F32,
ggml_type type_sq = GGML_TYPE_I32,
int64_t d_inner = 10,
int64_t d_conv = 10,
int64_t n_tokens = 10,
int64_t n_rs = 10)
: type_s(type_s), type_x(type_x), type_c(type_c), type_sq(type_sq), d_inner(d_inner), d_conv(d_conv), n_tokens(n_tokens), n_rs(n_rs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * s = ggml_new_tensor_3d (ctx, type_s, d_conv-1, d_inner, n_rs);
ggml_tensor * x = ggml_new_tensor_2d (ctx, type_x, d_inner, n_tokens);
ggml_tensor * c = ggml_new_tensor_2d (ctx, type_c, d_conv, d_inner);
ggml_tensor * sq = ggml_new_tensor_1d(ctx, type_sq, n_tokens);
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c, sq);
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
// pos
std::vector<int> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = rand() % n_rs;
}
ggml_backend_tensor_set(t, data.data(), 0, t->ne[0] * sizeof(int));
} else {
init_tensor_uniform(t);
}
}
}
};
enum llm_norm_type { enum llm_norm_type {
LLM_NORM, LLM_NORM,
LLM_NORM_RMS, LLM_NORM_RMS,
@ -2246,6 +2296,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
} }
} }
test_cases.emplace_back(new test_ssm_conv());
// these tests are disabled to save execution time, but they can be handy for debugging // these tests are disabled to save execution time, but they can be handy for debugging
#if 0 #if 0
test_cases.emplace_back(new test_llama(1)); test_cases.emplace_back(new test_llama(1));