mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
CUDA: fix scratch malloced on non-main device (#3220)
This commit is contained in:
parent
b541b4f0b1
commit
578d8c8f5c
@ -6970,6 +6970,7 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (g_scratch_buffer == nullptr) {
|
if (g_scratch_buffer == nullptr) {
|
||||||
|
ggml_cuda_set_device(g_main_device);
|
||||||
CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
|
CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user