ggml-alloc : fix backend assignments of views (#3982)

This commit is contained in:
slaren 2023-11-08 13:15:14 +01:00 committed by GitHub
parent 0a7c980b6f
commit 875fb42871
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -378,9 +378,13 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
} }
} }
static void init_view(struct ggml_allocr * alloc, struct ggml_tensor * view) { static void init_view(struct ggml_allocr * alloc, struct ggml_tensor * view, bool update_backend) {
assert(view->view_src != NULL && view->view_src->data != NULL); assert(view->view_src != NULL && view->view_src->data != NULL);
if (update_backend) {
view->backend = view->view_src->backend; view->backend = view->view_src->backend;
}
view->buffer = view->view_src->buffer; view->buffer = view->view_src->buffer;
view->data = (char *)view->view_src->data + view->view_offs; view->data = (char *)view->view_src->data + view->view_offs;
@ -394,7 +398,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
struct hash_node * ht = alloc->hash_table; struct hash_node * ht = alloc->hash_table;
if (node->data == NULL) { if (node->data == NULL) {
if (ggml_is_view(node)) { if (ggml_is_view(node)) {
init_view(alloc, node); init_view(alloc, node, true);
} else { } else {
// see if we can reuse a parent's buffer (inplace) // see if we can reuse a parent's buffer (inplace)
if (ggml_op_can_inplace(node->op)) { if (ggml_op_can_inplace(node->op)) {
@ -424,15 +428,14 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name); AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
node->view_src = view_src; node->view_src = view_src;
view_src_hn->n_views += 1; view_src_hn->n_views += 1;
init_view(alloc, node); init_view(alloc, node, false);
return; return;
} }
} } else {
else {
AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name); AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
node->view_src = parent; node->view_src = parent;
p_hn->n_views += 1; p_hn->n_views += 1;
init_view(alloc, node); init_view(alloc, node, false);
return; return;
} }
} }
@ -463,7 +466,7 @@ size_t ggml_allocr_alloc_graph_n(
hash_get(ht, view_src)->n_views += 1; hash_get(ht, view_src)->n_views += 1;
if (node->buffer == NULL && node->data != NULL) { if (node->buffer == NULL && node->data != NULL) {
// view of a pre-allocated tensor, didn't call init_view() yet // view of a pre-allocated tensor, didn't call init_view() yet
init_view(alloc, node); init_view(alloc, node, true);
} }
} }
@ -474,7 +477,7 @@ size_t ggml_allocr_alloc_graph_n(
} }
hash_get(ht, parent)->n_children += 1; hash_get(ht, parent)->n_children += 1;
if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) { if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
init_view(alloc, parent); init_view(alloc, parent, true);
} }
} }
} }