cuda : fix mul_mat_id with multi gpu

This commit is contained in:
slaren 2023-12-11 16:53:25 +01:00
parent 33e50f1b53
commit 296c945de5

View File

@ -8361,11 +8361,16 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
src1_row.ne[1] = 1;
dst_row.ne[1] = 1;
if (src1->backend == GGML_BACKEND_GPU) {
src1_row.extra = &src1_row_extra;
}
src1_row.nb[2] = src1_row.nb[1];
dst_row.nb[2] = dst_row.nb[1];
src1_row.nb[3] = src1_row.nb[1];
dst_row.nb[3] = dst_row.nb[1];
src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@ -8381,6 +8386,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
src1_row.data = (char *) src1->data + i01*src1->nb[1];
dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
dst_row.data = (char *) dst->data + i01*dst->nb[1];
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
}