mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
metal : unify mul_mv_id kernels (#6556)
This commit is contained in:
parent
4cc120c744
commit
fbbc030ba9
@ -1926,7 +1926,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
#if QK_K == 64
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
||||
#else
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
||||
#endif
|
||||
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
1323
ggml-metal.metal
1323
ggml-metal.metal
File diff suppressed because it is too large
Load Diff
1
ggml.c
1
ggml.c
@ -11012,7 +11012,6 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
}
|
||||
|
||||
// initialize matrix_row_counts
|
||||
GGML_ASSERT(wdata == wdata_src1_end);
|
||||
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||
|
||||
// group rows by src0 matrix
|
||||
|
@ -2014,6 +2014,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
for (int n_mats : {2, 4, 8}) {
|
||||
for (int id = 0; id < n_mats; id++) {
|
||||
for (bool v : {false, true}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 1, 256, v));
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user