cann: Fix Multi-NPU execution error (#8710)

* cann: fix multi-npu exec error

* cann: update comment  for ggml_backend_cann_supports_buft
This commit is contained in:
wangshuai09 2024-07-27 16:36:44 +08:00 committed by GitHub
parent 2b1f616b20
commit bfb4c74981
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1559,23 +1559,18 @@ GGML_CALL static bool ggml_backend_cann_cpy_tensor_async(
return false; return false;
} }
// need open both directions for memcpyasync between devices.
ggml_cann_set_device(cann_ctx_dst->device);
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
ggml_cann_set_device(cann_ctx_src->device); ggml_cann_set_device(cann_ctx_src->device);
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0)); ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
ACL_MEMCPY_DEVICE_TO_DEVICE, ACL_MEMCPY_DEVICE_TO_DEVICE,
cann_ctx_dst->stream())); cann_ctx_src->stream()));
// record event on src stream //TODO: workaround for Event didn`t work here.
if (!cann_ctx_src->copy_event) { aclrtSynchronizeStream(cann_ctx_src->stream());
ACL_CHECK(aclrtCreateEvent(&cann_ctx_src->copy_event));
}
ACL_CHECK(
aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
// wait on dst stream for the copy to complete
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(),
cann_ctx_src->copy_event));
} else { } else {
// src and dst are on the same backend // src and dst are on the same backend
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
@ -1763,8 +1758,8 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
* *
* This function determines whether the CANN backend supports the given backend * This function determines whether the CANN backend supports the given backend
* buffer type by comparing the device context of the backend and buffer type. * buffer type by comparing the device context of the backend and buffer type.
* It returns true if the device associated with the buffer type matches the * It returns true if the devices are same between the backend context and
* device associated with the backend. * buffer type context.
* *
* @param backend Pointer to the CANN backend. * @param backend Pointer to the CANN backend.
* @param buft Pointer to the backend buffer type to check. * @param buft Pointer to the backend buffer type to check.
@ -1773,9 +1768,14 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
*/ */
GGML_CALL static bool ggml_backend_cann_supports_buft( GGML_CALL static bool ggml_backend_cann_supports_buft(
ggml_backend_t backend, ggml_backend_buffer_type_t buft) { ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_cann_buffer_type_name; if (ggml_backend_buft_is_cann(buft)) {
ggml_backend_cann_context * cann_ctx =
GGML_UNUSED(backend); (ggml_backend_cann_context *)backend->context;
ggml_backend_cann_buffer_type_context * buft_ctx =
(ggml_backend_cann_buffer_type_context *)buft->context;
return buft_ctx->device == cann_ctx->device;
}
return false;
} }
/** /**