From 296c945de5fa1d36aa3680b58a84096733869c04 Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 11 Dec 2023 16:53:25 +0100 Subject: [PATCH] cuda : fix mul_mat_id with multi gpu --- ggml-cuda.cu | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 382897d59..9e1acd3f1 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -8361,11 +8361,16 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s src1_row.ne[1] = 1; dst_row.ne[1] = 1; - if (src1->backend == GGML_BACKEND_GPU) { - src1_row.extra = &src1_row_extra; - } + src1_row.nb[2] = src1_row.nb[1]; + dst_row.nb[2] = dst_row.nb[1]; + + src1_row.nb[3] = src1_row.nb[1]; + dst_row.nb[3] = dst_row.nb[1]; + + src1_row.extra = &src1_row_extra; dst_row.extra = &dst_row_extra; + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { //int32_t row_id; //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); @@ -8381,6 +8386,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s src1_row.data = (char *) src1->data + i01*src1->nb[1]; dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1]; + dst_row.data = (char *) dst->data + i01*dst->nb[1]; ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row); }