metal : update support condition for im2col + fix warning (#0)

This commit is contained in:
Georgi Gerganov 2024-09-08 09:57:57 +03:00
parent 385decbd63
commit a876861455
2 changed files with 4 additions and 3 deletions

View File

@ -799,8 +799,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
return ctx->support_simdgroup_reduction; return ctx->support_simdgroup_reduction;
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_ROPE: case GGML_OP_ROPE:
case GGML_OP_IM2COL:
return true; return true;
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_1D: case GGML_OP_POOL_1D:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
return false; return false;

View File

@ -24,6 +24,7 @@
#include <cfloat> #include <cfloat>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <cinttypes>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <random> #include <random>
@ -33,7 +34,6 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
// static RNG initialization (revisit if n_threads stops being constant) // static RNG initialization (revisit if n_threads stops being constant)
static const size_t n_threads = std::thread::hardware_concurrency(); static const size_t n_threads = std::thread::hardware_concurrency();
@ -869,7 +869,7 @@ struct test_case {
for (int64_t i = 0; i < ne; ++i) { // gradient algebraic for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
// check for nans // check for nans
if (!std::isfinite(ga[i])) { if (!std::isfinite(ga[i])) {
printf("[%s] nonfinite gradient at index %zu (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]); printf("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]);
ok = false; ok = false;
break; break;
} }