mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
ggml-cuda: Adding support for unified memory (#8035)
* Adding support for unified memory * adding again the documentation about unified memory * refactoring: Moved the unified memory code in the correct location. * Fixed compilation error when using hipblas * cleaning up the documentation * Updating the documentation Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * adding one more case where the PR should not be enabled --------- Co-authored-by: matteo serva <matteo.serva@gmail.com> Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
parent
b7a08fd5e0
commit
afbb4c1322
@ -178,7 +178,11 @@ For Jetson user, if you have Jetson Orin, you can try this: [Offical Support](ht
|
|||||||
cmake --build build --config Release
|
cmake --build build --config Release
|
||||||
```
|
```
|
||||||
|
|
||||||
The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The following compilation options are also available to tweak performance:
|
The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used.
|
||||||
|
|
||||||
|
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`.
|
||||||
|
|
||||||
|
The following compilation options are also available to tweak performance:
|
||||||
|
|
||||||
| Option | Legal values | Default | Description |
|
| Option | Legal values | Default | Description |
|
||||||
|-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
@ -130,7 +130,22 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
|||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
#else
|
#else
|
||||||
|
|
||||||
|
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||||
|
cudaError_t err;
|
||||||
|
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
|
||||||
|
{
|
||||||
|
err = cudaMallocManaged(ptr, size);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
err = cudaMalloc(ptr, size);
|
||||||
|
}
|
||||||
|
return err;
|
||||||
|
#else
|
||||||
return cudaMalloc(ptr, size);
|
return cudaMalloc(ptr, size);
|
||||||
|
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user