mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 23:00:46 +01:00
cuda : fix mul_mat_id with multi gpu
This commit is contained in:
parent
33e50f1b53
commit
296c945de5
12
ggml-cuda.cu
12
ggml-cuda.cu
@ -8361,11 +8361,16 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
|||||||
src1_row.ne[1] = 1;
|
src1_row.ne[1] = 1;
|
||||||
dst_row.ne[1] = 1;
|
dst_row.ne[1] = 1;
|
||||||
|
|
||||||
if (src1->backend == GGML_BACKEND_GPU) {
|
src1_row.nb[2] = src1_row.nb[1];
|
||||||
src1_row.extra = &src1_row_extra;
|
dst_row.nb[2] = dst_row.nb[1];
|
||||||
}
|
|
||||||
|
src1_row.nb[3] = src1_row.nb[1];
|
||||||
|
dst_row.nb[3] = dst_row.nb[1];
|
||||||
|
|
||||||
|
src1_row.extra = &src1_row_extra;
|
||||||
dst_row.extra = &dst_row_extra;
|
dst_row.extra = &dst_row_extra;
|
||||||
|
|
||||||
|
|
||||||
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
||||||
//int32_t row_id;
|
//int32_t row_id;
|
||||||
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
|
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
|
||||||
@ -8381,6 +8386,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
|||||||
src1_row.data = (char *) src1->data + i01*src1->nb[1];
|
src1_row.data = (char *) src1->data + i01*src1->nb[1];
|
||||||
|
|
||||||
dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
|
dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
|
||||||
|
dst_row.data = (char *) dst->data + i01*dst->nb[1];
|
||||||
|
|
||||||
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
|
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user