diff --git a/ggml.c b/ggml.c index bed90c844..5fb9e9a32 100644 --- a/ggml.c +++ b/ggml.c @@ -3212,35 +3212,42 @@ GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1]; } -GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); +static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) { + size_t next_nb = ggml_type_size(tensor->type); + if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) { + return false; + } + next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + if (tensor->ne[i] != 1) { + if (i > n) { + if (tensor->nb[i] != next_nb) { + return false; + } + next_nb *= tensor->ne[i]; + } else { + // this dimension does not need to be contiguous + next_nb = tensor->ne[i]*tensor->nb[i]; + } + } + } + return true; +} - return - (tensor->ne[0] == ggml_blck_size(tensor->type) || tensor->nb[0] == ggml_type_size(tensor->type)) && - (tensor->ne[1] == 1 || tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type)) && - (tensor->ne[2] == 1 || tensor->nb[2] == (tensor->nb[1]*tensor->ne[1])) && - (tensor->ne[3] == 1 || tensor->nb[3] == (tensor->nb[2]*tensor->ne[2])); +GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) { + return ggml_is_contiguous_0(tensor); } GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) { - return ggml_is_contiguous(tensor); + return ggml_is_contiguous_n(tensor, 0); } GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - (tensor->ne[0] == ggml_blck_size(tensor->type) || tensor->nb[0] == ggml_type_size(tensor->type)) && - (tensor->ne[2] == 1 || tensor->nb[2] == tensor->nb[1]*tensor->ne[1]) && - (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]); + return ggml_is_contiguous_n(tensor, 1); } GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - (tensor->ne[0] == ggml_blck_size(tensor->type) || tensor->nb[0] == ggml_type_size(tensor->type)) && - (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]); + return ggml_is_contiguous_n(tensor, 2); } GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {