diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann.cpp index ad5feea05..461febcc0 100644 --- a/ggml/src/ggml-cann.cpp +++ b/ggml/src/ggml-cann.cpp @@ -1559,23 +1559,18 @@ GGML_CALL static bool ggml_backend_cann_cpy_tensor_async( return false; } + // need open both directions for memcpyasync between devices. + ggml_cann_set_device(cann_ctx_dst->device); + ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0)); ggml_cann_set_device(cann_ctx_src->device); ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0)); + ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, - cann_ctx_dst->stream())); + cann_ctx_src->stream())); - // record event on src stream - if (!cann_ctx_src->copy_event) { - ACL_CHECK(aclrtCreateEvent(&cann_ctx_src->copy_event)); - } - - ACL_CHECK( - aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream())); - - // wait on dst stream for the copy to complete - ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), - cann_ctx_src->copy_event)); + //TODO: workaround for Event didn`t work here. + aclrtSynchronizeStream(cann_ctx_src->stream()); } else { // src and dst are on the same backend ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, @@ -1763,8 +1758,8 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { * * This function determines whether the CANN backend supports the given backend * buffer type by comparing the device context of the backend and buffer type. - * It returns true if the device associated with the buffer type matches the - * device associated with the backend. + * It returns true if the devices are same between the backend context and + * buffer type context. * * @param backend Pointer to the CANN backend. * @param buft Pointer to the backend buffer type to check. @@ -1773,9 +1768,14 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { */ GGML_CALL static bool ggml_backend_cann_supports_buft( ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_cann_buffer_type_name; - - GGML_UNUSED(backend); + if (ggml_backend_buft_is_cann(buft)) { + ggml_backend_cann_context * cann_ctx = + (ggml_backend_cann_context *)backend->context; + ggml_backend_cann_buffer_type_context * buft_ctx = + (ggml_backend_cann_buffer_type_context *)buft->context; + return buft_ctx->device == cann_ctx->device; + } + return false; } /**