mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 13:58:46 +01:00
ggml : support more contiguous cases
ggml-ci
This commit is contained in:
parent
f2a029bd9d
commit
cd026b48ef
45
ggml.c
45
ggml.c
@ -3212,35 +3212,42 @@ GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor) {
|
||||
return tensor->nb[0] > tensor->nb[1];
|
||||
}
|
||||
|
||||
GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
|
||||
size_t next_nb = ggml_type_size(tensor->type);
|
||||
if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) {
|
||||
return false;
|
||||
}
|
||||
next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
if (tensor->ne[i] != 1) {
|
||||
if (i > n) {
|
||||
if (tensor->nb[i] != next_nb) {
|
||||
return false;
|
||||
}
|
||||
next_nb *= tensor->ne[i];
|
||||
} else {
|
||||
// this dimension does not need to be contiguous
|
||||
next_nb = tensor->ne[i]*tensor->nb[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return
|
||||
(tensor->ne[0] == ggml_blck_size(tensor->type) || tensor->nb[0] == ggml_type_size(tensor->type)) &&
|
||||
(tensor->ne[1] == 1 || tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type)) &&
|
||||
(tensor->ne[2] == 1 || tensor->nb[2] == (tensor->nb[1]*tensor->ne[1])) &&
|
||||
(tensor->ne[3] == 1 || tensor->nb[3] == (tensor->nb[2]*tensor->ne[2]));
|
||||
GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
|
||||
return ggml_is_contiguous_0(tensor);
|
||||
}
|
||||
|
||||
GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
|
||||
return ggml_is_contiguous(tensor);
|
||||
return ggml_is_contiguous_n(tensor, 0);
|
||||
}
|
||||
|
||||
GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
return
|
||||
(tensor->ne[0] == ggml_blck_size(tensor->type) || tensor->nb[0] == ggml_type_size(tensor->type)) &&
|
||||
(tensor->ne[2] == 1 || tensor->nb[2] == tensor->nb[1]*tensor->ne[1]) &&
|
||||
(tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
|
||||
return ggml_is_contiguous_n(tensor, 1);
|
||||
}
|
||||
|
||||
GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
return
|
||||
(tensor->ne[0] == ggml_blck_size(tensor->type) || tensor->nb[0] == ggml_type_size(tensor->type)) &&
|
||||
(tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
|
||||
return ggml_is_contiguous_n(tensor, 2);
|
||||
}
|
||||
|
||||
GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
|
||||
|
Loading…
Reference in New Issue
Block a user