diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fe3b2cdfa..7d08574f5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -56,6 +56,7 @@ jobs: mkdir build cd build cmake .. \ + -DCMAKE_BUILD_RPATH="@loader_path" \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_CURL=ON \ -DGGML_METAL_USE_BF16=ON \ @@ -120,6 +121,7 @@ jobs: # Metal is disabled due to intermittent failures with Github runners not having a GPU: # https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 cmake -B build \ + -DCMAKE_BUILD_RPATH="@loader_path" \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_CURL=ON \ -DGGML_METAL=OFF \ @@ -160,8 +162,8 @@ jobs: path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip name: llama-bin-macos-x64.zip - ubuntu-latest-cmake: - runs-on: ubuntu-latest + ubuntu-cpu-cmake: + runs-on: ubuntu-22.04 steps: - name: Clone @@ -181,7 +183,10 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_RPC=ON + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_CURL=ON \ + -DGGML_RPC=ON cmake --build . --config Release -j $(nproc) - name: Test @@ -256,7 +261,10 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} cmake --build . --config ${{ matrix.build_type }} -j $(nproc) - name: Build (no OpenMP) @@ -265,7 +273,11 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DGGML_OPENMP=OFF + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=OFF cmake --build . --config ${{ matrix.build_type }} -j $(nproc) - name: Test @@ -295,7 +307,8 @@ jobs: run: | mkdir build cd build - cmake -DGGML_RPC=ON .. + cmake .. \ + -DGGML_RPC=ON cmake --build . --config Release -j $(nproc) - name: Test @@ -325,7 +338,8 @@ jobs: run: | mkdir build cd build - cmake -DGGML_VULKAN=ON .. + cmake .. \ + -DGGML_VULKAN=ON cmake --build . --config Release -j $(nproc) - name: Test @@ -352,13 +366,18 @@ jobs: - name: Build with native CMake HIP support id: cmake_build run: | - cmake -B build -S . -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" -DGGML_HIP=ON + cmake -B build -S . \ + -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \ + -DGGML_HIP=ON cmake --build build --config Release -j $(nproc) - name: Build with legacy HIP support id: cmake_build_legacy_hip run: | - cmake -B build2 -S . -DCMAKE_C_COMPILER=hipcc -DCMAKE_CXX_COMPILER=hipcc -DGGML_HIP=ON + cmake -B build2 -S . \ + -DCMAKE_C_COMPILER=hipcc \ + -DCMAKE_CXX_COMPILER=hipcc \ + -DGGML_HIP=ON cmake --build build2 --config Release -j $(nproc) ubuntu-22-cmake-musa: @@ -379,7 +398,8 @@ jobs: - name: Build with native CMake MUSA support id: cmake_build run: | - cmake -B build -S . -DGGML_MUSA=ON + cmake -B build -S . \ + -DGGML_MUSA=ON cmake --build build --config Release -j $(nproc) ubuntu-22-cmake-sycl: @@ -420,7 +440,10 @@ jobs: source /opt/intel/oneapi/setvars.sh mkdir build cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake .. \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx cmake --build . 
--config Release -j $(nproc) ubuntu-22-cmake-sycl-fp16: @@ -461,42 +484,13 @@ jobs: source /opt/intel/oneapi/setvars.sh mkdir build cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON .. + cmake .. \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx \ + -DGGML_SYCL_F16=ON cmake --build . --config Release -j $(nproc) - # TODO: build with GGML_METAL=OFF because test-backend-ops fail on "Apple Paravirtual device" and I don't know - # how to debug it. - # ref: https://github.com/ggerganov/llama.cpp/actions/runs/7132125951/job/19422043567?pr=4359#step:5:6584 - # would be great if we fix these - macOS-latest-cmake: - runs-on: macos-latest - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - - - name: Build - id: cmake_build - run: | - sysctl -a - mkdir build - cd build - cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF .. - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) - - - name: Test - id: cmake_test - run: | - cd build - ctest -L main --verbose --timeout 900 - macOS-latest-cmake-ios: runs-on: macos-latest @@ -827,7 +821,13 @@ jobs: - name: Build with CMake run: | - cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=89-real -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined -DLLAMA_FATAL_WARNINGS=ON + cmake -S . -B build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CUDA_ARCHITECTURES=89-real \ + -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_NATIVE=OFF \ + -DGGML_CUDA=ON cmake --build build windows-2019-cmake-cuda: @@ -916,7 +916,11 @@ jobs: shell: cmd run: | call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat" - cmake -S . -B build -G "Ninja Multi-Config" -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DGGML_RPC=ON + cmake -S . -B build -G "Ninja Multi-Config" \ + -DLLAMA_BUILD_SERVER=ON \ + -DGGML_NATIVE=OFF \ + -DGGML_CUDA=ON \ + -DGGML_RPC=ON set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config Release -j %NINJA_JOBS% -t ggml cmake --build build --config Release @@ -1201,8 +1205,7 @@ jobs: runs-on: ubuntu-latest needs: - - ubuntu-latest-cmake - - macOS-latest-cmake + - ubuntu-cpu-cmake - windows-latest-cmake - windows-2019-cmake-cuda - windows-latest-cmake-hip-release @@ -1461,3 +1464,37 @@ jobs: # popd # emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} # make + + openEuler-latest-cmake-cann: + if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} + defaults: + run: + shell: bash -el {0} + runs-on: ubuntu-24.04-arm + strategy: + matrix: + cann: + - '8.0.rc3.beta1-910b-openeuler22.03-py3.10' + device: + - 'ascend910b3' + build: + - 'Release' + container: ascendai/cann:${{ matrix.cann }} + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Dependencies + run: | + yum update -y + yum install -y git gcc gcc-c++ make cmake + + - name: Build + run: | + export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH} + + cmake -S . 
-B build \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_CANN=on \ + -DSOC_TYPE=${{ matrix.device }} + cmake --build build -j $(nproc) diff --git a/CMakeLists.txt b/CMakeLists.txt index 42caed486..7e41a44d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,7 @@ endif() list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(LLAMA_STANDALONE ON) diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 8d8312e91..89ddbd669 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -133,7 +133,7 @@ The docker build option is currently limited to *intel GPU* targets. ### Build image ```sh # Using FP16 -docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" -f .devops/llama-cli-intel.Dockerfile . +docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile . ``` *Notes*: diff --git a/docs/build.md b/docs/build.md index 3b0d2211d..dd6495028 100644 --- a/docs/build.md +++ b/docs/build.md @@ -286,7 +286,7 @@ You don't need to install Vulkan SDK. It will be installed inside the container. ```sh # Build the image -docker build -t llama-cpp-vulkan -f .devops/llama-cli-vulkan.Dockerfile . +docker build -t llama-cpp-vulkan --target light -f .devops/vulkan.Dockerfile . # Then, use it: docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-vulkan -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 diff --git a/docs/docker.md b/docs/docker.md index 8d90e6ded..dac9a9ec1 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -60,9 +60,9 @@ Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia ## Building Docker locally ```bash -docker build -t local/llama.cpp:full-cuda -f .devops/full-cuda.Dockerfile . -docker build -t local/llama.cpp:light-cuda -f .devops/llama-cli-cuda.Dockerfile . -docker build -t local/llama.cpp:server-cuda -f .devops/llama-server-cuda.Dockerfile . +docker build -t local/llama.cpp:full-cuda --target full -f .devops/cuda.Dockerfile . +docker build -t local/llama.cpp:light-cuda --target light -f .devops/cuda.Dockerfile . +docker build -t local/llama.cpp:server-cuda --target server -f .devops/cuda.Dockerfile . ``` You may want to pass in some different `ARGS`, depending on the CUDA environment supported by your container host, as well as the GPU architecture. @@ -95,9 +95,9 @@ Assuming one has the [mt-container-toolkit](https://developer.mthreads.com/musa/ ## Building Docker locally ```bash -docker build -t local/llama.cpp:full-musa -f .devops/full-musa.Dockerfile . -docker build -t local/llama.cpp:light-musa -f .devops/llama-cli-musa.Dockerfile . -docker build -t local/llama.cpp:server-musa -f .devops/llama-server-musa.Dockerfile . +docker build -t local/llama.cpp:full-musa --target full -f .devops/musa.Dockerfile . +docker build -t local/llama.cpp:light-musa --target light -f .devops/musa.Dockerfile . +docker build -t local/llama.cpp:server-musa --target server -f .devops/musa.Dockerfile . ``` You may want to pass in some different `ARGS`, depending on the MUSA environment supported by your container host, as well as the GPU architecture. 
diff --git a/examples/run/README.md b/examples/run/README.md index a06805441..89a552079 100644 --- a/examples/run/README.md +++ b/examples/run/README.md @@ -3,11 +3,10 @@ The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models. ```bash -llama-run granite-code +llama-run granite3-moe ``` ```bash -llama-run -h Description: Runs a llm @@ -17,7 +16,7 @@ Usage: Options: -c, --context-size Context size (default: 2048) - -n, --ngl + -n, -ngl, --ngl Number of GPU layers (default: 0) --temp Temperature (default: 0.8) diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index 26f3583bd..582ccc0d3 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/webui/index.html b/examples/server/webui/index.html index 2180ef4ad..d3893ea4e 100644 --- a/examples/server/webui/index.html +++ b/examples/server/webui/index.html @@ -141,6 +141,7 @@ :msg="pendingMsg" :key="pendingMsg.id" :is-generating="isGenerating" + :show-thought-in-progress="config.showThoughtInProgress" :edit-user-msg-and-regenerate="() => {}" :regenerate-msg="() => {}"> @@ -202,6 +203,20 @@ + +
+ Reasoning models +
+
+ + Expand thought process by default for generating message
+
+ + Exclude thought process when sending request to API (Recommended for DeepSeek-R1) +
+
+
Advanced config @@ -261,7 +276,17 @@
- +
+ + + + Thinking + + Thought Process + + +
+
`; } + +/** + * filter out redundant fields upon sending to API + * @param {Array} messages + * @returns {Array} + */ +function normalizeMsgsForAPI(messages) { + return messages.map((msg) => { + return { + role: msg.role, + content: msg.content, + }; + }); +} + +/** + * recommended for DeepsSeek-R1, filter out content between and tags + * @param {Array} messages + * @returns {Array} + */ +function filterThoughtFromMsgs(messages) { + return messages.map((msg) => { + return { + role: msg.role, + content: msg.role === 'assistant' + ? msg.content.split('').at(-1).trim() + : msg.content, + }; + }); +} diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 185079aa4..123c755ac 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -58,7 +58,8 @@ else() set(GGML_BLAS_VENDOR_DEFAULT "Generic") endif() -if (CMAKE_CROSSCOMPILING) +if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH}) + message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF") set(GGML_NATIVE_DEFAULT OFF) else() set(GGML_NATIVE_DEFAULT ON) @@ -153,6 +154,7 @@ option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashA option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) option(GGML_HIP "ggml: use HIP" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 0ed92b3ff..9e627da8c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32( float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); } @@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32( float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); ggml_vec_mad_f32(ne0, d, s0, *s1); } diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 35a1c876c..2ccb4b472 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; case GGML_OP_OUT_PROD: - return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32; + return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && + src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; default: return true; } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index c7b6be4e2..ce4b9cfb5 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const 
src1_t * s template static __global__ void k_repeat_back( - const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, - const int64_t ne0, const int64_t ne1, const int64_t ne2) { + const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) { - const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; - const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y; - const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z; + const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x; + const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y; + const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z; + const int64_t tid2 = tid23 % ne2; + const int64_t tid3 = tid23 / ne2; if (tid0 >= ne0) { return; } T sum = 0; - for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { - for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { - for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { - sum += src[i2*ne01*ne00 + i1*ne00 + i0]; + for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { + for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { + for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { + for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { + sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00]; + } } } } - dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; + dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; } template @@ -274,12 +279,14 @@ struct bin_bcast_cuda { template static void repeat_back_cuda( - const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, - const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) { + const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); - const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2); - k_repeat_back<<>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2); + const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3); + k_repeat_back<<>> + (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3); } template @@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(src0->type == dst->type); - GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_can_repeat(dst, src0)); cudaStream_t stream = ctx.stream(); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - GGML_ASSERT(src0->ne[3] == 1); + GGML_TENSOR_UNARY_OP_LOCALS; - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - GGML_ASSERT(dst->ne[3] == 1); + GGML_ASSERT(ne2*ne3 <= (1 << 15)); + + const size_t ts = ggml_type_size(src0->type); + const size_t s00 = nb00 / ts; + const size_t s01 = nb01 / ts; + const size_t s02 = nb02 / ts; + const size_t s03 = nb03 / ts; switch (dst->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; float * dst_d = (float *) dst->data; - repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, 
ne0, ne1, ne2, stream); + repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream); } break; default: { GGML_ASSERT(false); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2c0a56226..a79fa83c5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -588,7 +588,7 @@ struct ggml_tensor_extra_gpu { }; -#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS) +#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS) #define USE_CUDA_GRAPH #endif diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7fd1fc853..a53a1bbd0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -62,7 +62,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails - cudaGetDevice(&id); + (void)cudaGetDevice(&id); GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg); GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); @@ -152,7 +152,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_CUDA_NO_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -164,7 +164,7 @@ static ggml_cuda_device_info ggml_cuda_init() { alloc_prop.location.id = id; CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // !defined(GGML_CUDA_NO_VMM) info.devices[id].vmm = !!device_vmm; cudaDeviceProp prop; @@ -300,7 +300,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_CUDA_NO_VMM) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -309,6 +309,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { size_t pool_used = 0; size_t pool_size = 0; size_t granularity; +#if defined(GGML_USE_HIP) + std::vector> mappings; +#endif explicit ggml_cuda_pool_vmm(int device) : device(device), @@ -317,7 +320,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { ~ggml_cuda_pool_vmm() { if (pool_addr != 0) { +#if defined(GGML_USE_HIP) + // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 + for (std::pair & mapping : mappings) { + CU_CHECK(cuMemUnmap(mapping.first, mapping.second)); + } +#else CU_CHECK(cuMemUnmap(pool_addr, pool_size)); +#endif CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE)); } } @@ -350,7 +360,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { } // map at the end of the pool - CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0)); + CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); + CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); +#if defined(GGML_USE_HIP) + mappings.push_back({start_ptr, reserve_size}); +#endif // the memory allocation handle is no longer needed after mapping CU_CHECK(cuMemRelease(handle)); @@ -360,7 +374,7 @@ struct ggml_cuda_pool_vmm : 
public ggml_cuda_pool { access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; access.location.id = device; access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1)); + CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); // add to the pool pool_size += reserve_size; @@ -372,7 +386,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(pool_addr != 0); - void * ptr = (void *) (pool_addr + pool_used); + void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used)); *actual_size = size; pool_used += size; @@ -391,17 +405,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { pool_used -= size; // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void *) (pool_addr + pool_used)); + GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); } }; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // !defined(GGML_CUDA_NO_VMM) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_CUDA_NO_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // !defined(GGML_CUDA_NO_VMM) return std::unique_ptr(new ggml_cuda_pool_leg(device)); } @@ -547,7 +561,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err)); return nullptr; } @@ -962,7 +976,7 @@ static void * ggml_cuda_host_malloc(size_t size) { cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, cudaGetErrorString(err)); return nullptr; @@ -1082,7 +1096,9 @@ static void ggml_cuda_op_mul_mat_cublas( const int compute_capability = ggml_cuda_info().devices[id].cc; - if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { + const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + + if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { @@ -1103,28 +1119,38 @@ static void ggml_cuda_op_mul_mat_cublas( to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? 
(const half *) src1_ddf_i : src1_as_f16.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); - - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; - - cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { - cu_compute_type = CUBLAS_COMPUTE_32F; - } CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - CUBLAS_CHECK( - cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, - row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + if (compute_capability == GGML_CUDA_CC_CDNA) { + const float alpha = 1.0f; + const float beta = 0.0f; + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta, dst_dd_i, CUDA_R_32F, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } } else { ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); ggml_cuda_pool_alloc src1_ddq_as_f32(ctx.pool(id)); @@ -1197,7 +1223,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { CUDA_CHECK(err); } else { // reset the error - cudaGetLastError(); + (void)cudaGetLastError(); } } else { cudaError_t err = cudaDeviceDisablePeerAccess(id_other); @@ -1205,7 +1231,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { CUDA_CHECK(err); } else { // reset the error - cudaGetLastError(); + (void)cudaGetLastError(); } } } @@ -1613,10 +1639,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; cudaDataType_t cu_data_type = CUDA_R_16F; - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { - cu_compute_type = CUBLAS_COMPUTE_32F; - } - // dst strides size_t nbd2 = dst->nb[2]; size_t nbd3 = dst->nb[3]; @@ -1645,6 +1667,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } + if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { + cu_compute_type = CUBLAS_COMPUTE_32F; + alpha = &alpha_f32; + beta = &beta_f32; + } + GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -2438,7 +2466,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto if (stat == cudaErrorInvalidDeviceFunction) { // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. // We don't need to update blas nodes, so clear error and move on. 
- cudaGetLastError(); + (void)cudaGetLastError(); } else { GGML_ASSERT(stat == cudaSuccess); } @@ -2493,14 +2521,20 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { cudaGraphExecUpdateResultInfo result_info; +#ifdef __HIP_PLATFORM_AMD__ + hipGraphNode_t errorNode; + hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); +#else cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); +#endif if (stat == cudaErrorGraphExecUpdateFailure) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); #endif + // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate - cudaGetLastError(); + (void)cudaGetLastError(); CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); cuda_ctx->cuda_graph->instance = nullptr; CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); @@ -2728,7 +2762,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, cudaGetErrorString(err)); @@ -2748,7 +2782,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) { cudaError_t err = cudaHostUnregister(buffer); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); } } @@ -3002,7 +3036,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; case GGML_OP_REPEAT_BACK: - return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1; + return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15); case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index e3b912d87..4fb466ca0 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda( int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; - if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA + if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2 switch(ncols_y) { case 1: nwarps = 4; @@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda( break; } } + const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; const dim3 block_nums(nblocks, 1, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index 73e3e2c47..c9b2b699c 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { CUBLAS_CHECK(cublasSetStream(handle, stream)); + const int64_t lda = nb01 / sizeof(float); + const int64_t ldc = nb1 / sizeof(float); + const bool src1_T = ggml_is_transposed(src1); const cublasOperation_t src1_cublas_op = src1_T ? 
CUBLAS_OP_N : CUBLAS_OP_T; const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); @@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { CUBLAS_CHECK( cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, ne0, ne1, ne01, - &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, src1_d + i3 *s13 + i2 *s12, ldb, - &beta, dst_d + i3 *s3 + i2 *s2, ne0)); + &beta, dst_d + i3 *s3 + i2 *s2, ldc)); } } } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index c905b15d7..8594093f0 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -19,6 +19,12 @@ #define CUBLAS_TF32_TENSOR_OP_MATH 0 #define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_32F HIPBLAS_R_32F +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate @@ -74,6 +80,21 @@ #define cudaMemGetInfo hipMemGetInfo #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute #define cudaStreamCreateWithFlags hipStreamCreateWithFlags #define cudaStreamDestroy hipStreamDestroy #define cudaStreamFireAndForget hipStreamFireAndForget @@ -81,6 +102,28 @@ #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define 
cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define __trap() do { abort(); __builtin_unreachable(); } while(0) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index d090ba9bd..77994a698 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -92,6 +92,14 @@ if (GGML_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() +if (GGML_HIP_GRAPHS) + add_compile_definitions(GGML_HIP_GRAPHS) +endif() + +if (GGML_CUDA_NO_VMM) + add_compile_definitions(GGML_CUDA_NO_VMM) +endif() + if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) target_link_libraries(ggml-hip PRIVATE hip::device) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b1d0d4913..92c4294c5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5339,7 +5339,7 @@ static void ggml_compute_backward( } break; case GGML_OP_MUL: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1)); } if (src1_needs_grads) { struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad); @@ -5431,21 +5431,25 @@ static void ggml_compute_backward( // src1.shape [n,p,qq,rr] if (src0_needs_grads) { - struct ggml_tensor * s1_tg = + GGML_ASSERT(grad->ne[2] == src1->ne[2]); + GGML_ASSERT(grad->ne[3] == src1->ne[3]); + struct ggml_tensor * tmp = ggml_out_prod(ctx, // [n,m,qq,rr] src1, // [n,p,qq,rr] grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = ggml_repeat_back(ctx, s1_tg, src0); + if (!ggml_are_same_shape(tmp, src0)) { + GGML_ASSERT(tmp->ne[0] == src0->ne[0]); + GGML_ASSERT(tmp->ne[1] == src0->ne[1]); + GGML_ASSERT(tmp->ne[3] == 1); + + const int64_t nr2 = tmp->ne[2] / src0->ne[2]; + const size_t nb2 = tmp->nb[2] * nr2; + const size_t nb3 = tmp->nb[2]; + + tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0); + tmp = ggml_repeat_back(ctx, tmp, src0); } - ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/); + ggml_add_or_set(ctx, cgraph, isrc0, tmp); } if (src1_needs_grads) { ggml_add_or_set(ctx, cgraph, isrc1, @@ -5514,7 +5518,9 @@ static void ggml_compute_backward( if (src0_needs_grads) { GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0])); GGML_ASSERT(ggml_is_contiguous(grad)); - ggml_add_or_set(ctx, cgraph, isrc0, grad); + GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0)); + ggml_add_or_set(ctx, cgraph, isrc0, + ggml_are_same_shape(tensor, src0) ? 
grad : ggml_reshape(ctx, grad, src0)); } } break; case GGML_OP_RESHAPE: { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 381956a04..468016403 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1302,6 +1302,59 @@ struct test_repeat : public test_case { } }; +// GGML_OP_REPEAT_BACK +struct test_repeat_back : public test_case { + const ggml_type type; + const std::array ne; + const std::array nr; + const bool v; // whether src is a noncontiguous view + + std::string vars() override { + return VARS_TO_STR4(type, ne, nr, v); + } + + size_t op_size(ggml_tensor * t) override { + return ggml_nbytes(t) * 2; + } + + test_repeat_back(ggml_type type = GGML_TYPE_F32, + std::array ne = {8, 6, 4, 2}, + std::array nr = {2, 2, 2, 2}, + bool v = false) + : type(type), ne(ne), nr(nr), v(v) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]); + ggml_set_name(src, "src"); + + if (v) { + GGML_ASSERT(ne[0] % 2 == 0); + GGML_ASSERT(ne[1] % 2 == 0); + GGML_ASSERT(ne[2] % 2 == 0); + GGML_ASSERT(ne[3] % 2 == 0); + GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1); + GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1); + GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1); + GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1); + + const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2; + const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2; + const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2; + const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2; + + src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0); + } + + ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(target, "target"); + + ggml_tensor * out = ggml_repeat_back(ctx, src, target); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_DUP struct test_dup : public test_case { const ggml_type type; @@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case { return 5e-4; } + int64_t grad_nmax() override { + return 20000; + } + uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1]; @@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case { a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]); b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); + if (!ggml_is_quantized(type_a)) { + if (bs[1] == 1 && nr[1] == 1) { + ggml_set_param(ctx, a); + } + ggml_set_param(ctx, b); + } ggml_set_name(a, "a"); ggml_set_name(b, "b"); @@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case { } else { a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); + if (!ggml_is_quantized(type_a)) { + if (bs[1] == 1 && nr[1] == 1) { + ggml_set_param(ctx, a); + } + ggml_set_param(ctx, b); + } ggml_set_name(a, "a"); ggml_set_name(b, "b"); } @@ -3798,6 +3863,16 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2})); } + for (bool view : {false, true}) { + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view)); 
+ test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view)); + } + test_cases.emplace_back(new test_dup(GGML_TYPE_F32)); test_cases.emplace_back(new test_dup(GGML_TYPE_F16)); test_cases.emplace_back(new test_dup(GGML_TYPE_I32)); @@ -3919,21 +3994,25 @@ static std::vector> make_test_cases_eval() { for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { // test cases without permutation - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, 
{2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2})); // test cases with permutation test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
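
Note on the `GGML_OP_REPEAT_BACK` changes above: the rewritten CUDA `k_repeat_back` kernel, the relaxed CUDA `supports_op` check (which now allows `ne[3] > 1` as long as `ne[2]*ne[3] <= 1 << 15`), and the new `test_repeat_back` cases all implement or exercise the same reduction, extended to the 4th dimension and to non-contiguous (view) sources. `ggml_repeat_back` collapses a tensor produced by repetition/broadcasting back onto the smaller shape by summing every repetition. The sketch below is a minimal host-side C++ reference of that reduction for float data, written only to illustrate the semantics; the name `repeat_back_ref` and its parameter layout are hypothetical and not part of the ggml API.

```cpp
#include <cstdint>

// dst has extents ne[d]; src has extents ne[d]*nr[d] and element strides s[d].
// With a non-contiguous src (the v=true test_repeat_back cases built via
// ggml_view_4d), s[d] differs from the contiguous strides.
static void repeat_back_ref(const float * src, float * dst,
                            const int64_t ne[4],  // dst extents
                            const int64_t nr[4],  // repetitions per dimension
                            const int64_t s[4]) { // src element strides
    for (int64_t i3 = 0; i3 < ne[3]; ++i3) {
        for (int64_t i2 = 0; i2 < ne[2]; ++i2) {
            for (int64_t i1 = 0; i1 < ne[1]; ++i1) {
                for (int64_t i0 = 0; i0 < ne[0]; ++i0) {
                    float sum = 0.0f;
                    // accumulate every tile of src that maps onto this dst element
                    for (int64_t r3 = 0; r3 < nr[3]; ++r3) {
                        for (int64_t r2 = 0; r2 < nr[2]; ++r2) {
                            for (int64_t r1 = 0; r1 < nr[1]; ++r1) {
                                for (int64_t r0 = 0; r0 < nr[0]; ++r0) {
                                    sum += src[(i3 + r3*ne[3])*s[3] +
                                               (i2 + r2*ne[2])*s[2] +
                                               (i1 + r1*ne[1])*s[1] +
                                               (i0 + r0*ne[0])*s[0]];
                                }
                            }
                        }
                    }
                    dst[((i3*ne[2] + i2)*ne[1] + i1)*ne[0] + i0] = sum;
                }
            }
        }
    }
}
```

This is the same reduction the updated `MUL_MAT` backward path in `ggml.c` relies on: when the forward pass broadcasted over the batch dimensions, the `ggml_out_prod` result is reinterpreted with `ggml_view_4d` and then summed onto `src0`'s shape with `ggml_repeat_back`.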