diff --git a/.devops/cpu.Dockerfile b/.devops/cpu.Dockerfile index 8d020f16c..522ee8147 100644 --- a/.devops/cpu.Dockerfile +++ b/.devops/cpu.Dockerfile @@ -2,6 +2,10 @@ ARG UBUNTU_VERSION=22.04 FROM ubuntu:$UBUNTU_VERSION AS build +ARG TARGETARCH + +ARG GGML_CPU_ARM_ARCH=armv8-a + RUN apt-get update && \ apt-get install -y build-essential git cmake libcurl4-openssl-dev @@ -9,7 +13,14 @@ WORKDIR /app COPY . . -RUN cmake -S . -B build -DGGML_BACKEND_DL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_CURL=ON -DCMAKE_BUILD_TYPE=Release && \ +RUN if [ "$TARGETARCH" = "amd64" ]; then \ + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \ + elif [ "$TARGETARCH" = "arm64" ]; then \ + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=${GGML_CPU_ARM_ARCH}; \ + else \ + echo "Unsupported architecture"; \ + exit 1; \ + fi && \ cmake --build build -j $(nproc) RUN mkdir -p /app/lib && \ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fe3b2cdfa..37cb6b1e7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -56,6 +56,7 @@ jobs: mkdir build cd build cmake .. \ + -DCMAKE_BUILD_RPATH="@loader_path" \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_CURL=ON \ -DGGML_METAL_USE_BF16=ON \ @@ -120,6 +121,7 @@ jobs: # Metal is disabled due to intermittent failures with Github runners not having a GPU: # https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 cmake -B build \ + -DCMAKE_BUILD_RPATH="@loader_path" \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_CURL=ON \ -DGGML_METAL=OFF \ @@ -160,8 +162,8 @@ jobs: path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip name: llama-bin-macos-x64.zip - ubuntu-latest-cmake: - runs-on: ubuntu-latest + ubuntu-cpu-cmake: + runs-on: ubuntu-22.04 steps: - name: Clone @@ -181,7 +183,10 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_RPC=ON + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_CURL=ON \ + -DGGML_RPC=ON cmake --build . --config Release -j $(nproc) - name: Test @@ -256,7 +261,10 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} cmake --build . --config ${{ matrix.build_type }} -j $(nproc) - name: Build (no OpenMP) @@ -265,7 +273,11 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DGGML_OPENMP=OFF + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=OFF cmake --build . --config ${{ matrix.build_type }} -j $(nproc) - name: Test @@ -295,7 +307,8 @@ jobs: run: | mkdir build cd build - cmake -DGGML_RPC=ON .. + cmake .. \ + -DGGML_RPC=ON cmake --build . --config Release -j $(nproc) - name: Test @@ -325,7 +338,8 @@ jobs: run: | mkdir build cd build - cmake -DGGML_VULKAN=ON .. + cmake .. \ + -DGGML_VULKAN=ON cmake --build . --config Release -j $(nproc) - name: Test @@ -352,13 +366,18 @@ jobs: - name: Build with native CMake HIP support id: cmake_build run: | - cmake -B build -S . -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" -DGGML_HIP=ON + cmake -B build -S . \ + -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \ + -DGGML_HIP=ON cmake --build build --config Release -j $(nproc) - name: Build with legacy HIP support id: cmake_build_legacy_hip run: | - cmake -B build2 -S . -DCMAKE_C_COMPILER=hipcc -DCMAKE_CXX_COMPILER=hipcc -DGGML_HIP=ON + cmake -B build2 -S . \ + -DCMAKE_C_COMPILER=hipcc \ + -DCMAKE_CXX_COMPILER=hipcc \ + -DGGML_HIP=ON cmake --build build2 --config Release -j $(nproc) ubuntu-22-cmake-musa: @@ -379,7 +398,8 @@ jobs: - name: Build with native CMake MUSA support id: cmake_build run: | - cmake -B build -S . -DGGML_MUSA=ON + cmake -B build -S . \ + -DGGML_MUSA=ON cmake --build build --config Release -j $(nproc) ubuntu-22-cmake-sycl: @@ -420,7 +440,10 @@ jobs: source /opt/intel/oneapi/setvars.sh mkdir build cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake .. \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx cmake --build . --config Release -j $(nproc) ubuntu-22-cmake-sycl-fp16: @@ -461,42 +484,13 @@ jobs: source /opt/intel/oneapi/setvars.sh mkdir build cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON .. + cmake .. \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx \ + -DGGML_SYCL_F16=ON cmake --build . --config Release -j $(nproc) - # TODO: build with GGML_METAL=OFF because test-backend-ops fail on "Apple Paravirtual device" and I don't know - # how to debug it. - # ref: https://github.com/ggerganov/llama.cpp/actions/runs/7132125951/job/19422043567?pr=4359#step:5:6584 - # would be great if we fix these - macOS-latest-cmake: - runs-on: macos-latest - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - - - name: Build - id: cmake_build - run: | - sysctl -a - mkdir build - cd build - cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF .. - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) - - - name: Test - id: cmake_test - run: | - cd build - ctest -L main --verbose --timeout 900 - macOS-latest-cmake-ios: runs-on: macos-latest @@ -827,7 +821,13 @@ jobs: - name: Build with CMake run: | - cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=89-real -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined -DLLAMA_FATAL_WARNINGS=ON + cmake -S . -B build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CUDA_ARCHITECTURES=89-real \ + -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_NATIVE=OFF \ + -DGGML_CUDA=ON cmake --build build windows-2019-cmake-cuda: @@ -916,7 +916,11 @@ jobs: shell: cmd run: | call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat" - cmake -S . -B build -G "Ninja Multi-Config" -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DGGML_RPC=ON + cmake -S . -B build -G "Ninja Multi-Config" ^ + -DLLAMA_BUILD_SERVER=ON ^ + -DGGML_NATIVE=OFF ^ + -DGGML_CUDA=ON ^ + -DGGML_RPC=ON set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config Release -j %NINJA_JOBS% -t ggml cmake --build build --config Release @@ -1069,7 +1073,12 @@ jobs: run: | $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DGGML_RPC=ON + cmake -G "Unix Makefiles" -B build -S . ` + -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` + -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` + -DCMAKE_BUILD_TYPE=Release ` + -DGGML_HIP=ON ` + -DGGML_RPC=ON cmake --build build -j ${env:NUMBER_OF_PROCESSORS} windows-latest-cmake-hip-release: @@ -1107,7 +1116,13 @@ jobs: run: | $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS=${{ matrix.gpu_target }} -DGGML_RPC=ON + cmake -G "Unix Makefiles" -B build -S . ` + -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` + -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` + -DCMAKE_BUILD_TYPE=Release ` + -DAMDGPU_TARGETS=${{ matrix.gpu_target }} ` + -DGGML_HIP=ON ` + -DGGML_RPC=ON cmake --build build -j ${env:NUMBER_OF_PROCESSORS} md "build\bin\rocblas\library\" cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\" @@ -1201,8 +1216,7 @@ jobs: runs-on: ubuntu-latest needs: - - ubuntu-latest-cmake - - macOS-latest-cmake + - ubuntu-cpu-cmake - windows-latest-cmake - windows-2019-cmake-cuda - windows-latest-cmake-hip-release @@ -1461,3 +1475,37 @@ jobs: # popd # emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} # make + + openEuler-latest-cmake-cann: + if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} + defaults: + run: + shell: bash -el {0} + runs-on: ubuntu-24.04-arm + strategy: + matrix: + cann: + - '8.0.rc3.beta1-910b-openeuler22.03-py3.10' + device: + - 'ascend910b3' + build: + - 'Release' + container: ascendai/cann:${{ matrix.cann }} + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Dependencies + run: | + yum update -y + yum install -y git gcc gcc-c++ make cmake + + - name: Build + run: | + export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH} + + cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_CANN=on \ + -DSOC_TYPE=${{ matrix.device }} + cmake --build build -j $(nproc) diff --git a/CMakeLists.txt b/CMakeLists.txt index 42caed486..e7f520582 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,7 @@ endif() list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(LLAMA_STANDALONE ON) @@ -49,6 +50,7 @@ endif() if (MSVC) add_compile_options("$<$:/utf-8>") add_compile_options("$<$:/utf-8>") + add_compile_options(/bigobj) endif() # diff --git a/Makefile b/Makefile index 19ae0d5f1..295522ba3 100644 --- a/Makefile +++ b/Makefile @@ -1361,7 +1361,9 @@ llama-server: \ examples/server/httplib.h \ examples/server/index.html.hpp \ examples/server/loading.html.hpp \ + common/chat-template.hpp \ common/json.hpp \ + common/minja.hpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) diff --git a/README.md b/README.md index 784669ce1..97d028670 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ## Hot topics -- **Introducing GGUF-my-LoRA** https://github.com/ggerganov/llama.cpp/discussions/10123 +- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode +- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim +- Introducing GGUF-my-LoRA https://github.com/ggerganov/llama.cpp/discussions/10123 - Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669 - Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor) diff --git a/cmake/build-info.cmake b/cmake/build-info.cmake index ea3dc55c8..c1a456e17 100644 --- a/cmake/build-info.cmake +++ b/cmake/build-info.cmake @@ -44,7 +44,7 @@ if(MSVC) set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME}) else() execute_process( - COMMAND sh -c "$@ --version | head -1" _ ${CMAKE_C_COMPILER} + COMMAND sh -c "\"$@\" --version | head -1" _ ${CMAKE_C_COMPILER} OUTPUT_VARIABLE OUT OUTPUT_STRIP_TRAILING_WHITESPACE ) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index df1cdf9a5..24b7f8741 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -56,6 +56,7 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat-template.hpp common.cpp common.h console.cpp @@ -64,6 +65,7 @@ add_library(${TARGET} STATIC json.hpp log.cpp log.h + minja.hpp ngram-cache.cpp ngram-cache.h sampling.cpp diff --git a/common/arg.cpp b/common/arg.cpp index dede335fb..a6226a34b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -133,7 +133,8 @@ static void common_params_handle_model_default( const std::string & model_url, std::string & hf_repo, std::string & hf_file, - const std::string & hf_token) { + const std::string & hf_token, + const std::string & model_default) { if (!hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model if (hf_file.empty()) { @@ -163,7 +164,7 @@ static void common_params_handle_model_default( model = fs_get_cache_file(string_split(f, '/').back()); } } else if (model.empty()) { - model = DEFAULT_MODEL_PATH; + model = model_default; } } @@ -299,8 +300,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context } // TODO: refactor model params in a common struct - common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token); - common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token); + common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH); + common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, ""); + common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, ""); if (params.escape) { string_process_escapes(params.prompt); @@ -323,6 +325,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); } + if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) { + throw std::runtime_error(string_format( + "error: the supplied chat template is not supported: %s%s\n", + params.chat_template.c_str(), + params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates" + )); + } + return true; } @@ -1629,6 +1639,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.hf_repo = value; } ).set_env("LLAMA_ARG_HF_REPO")); + add_opt(common_arg( + {"-hfd", "-hfrd", "--hf-repo-draft"}, "/[:quant]", + "Same as --hf-repo, but for the draft model (default: unused)", + [](common_params & params, const std::string & value) { + params.speculative.hf_repo = value; + } + ).set_env("LLAMA_ARG_HFD_REPO")); add_opt(common_arg( {"-hff", "--hf-file"}, "FILE", "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)", @@ -1938,24 +1955,44 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--jinja"}, + "use jinja template for chat (default: disabled)", + [](common_params & params) { + params.use_jinja = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( "set custom jinja chat template (default: template taken from model's metadata)\n" "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() ), [](common_params & params, const std::string & value) { - if (!common_chat_verify_template(value)) { - throw std::runtime_error(string_format( - "error: the supplied chat template is not supported: %s\n" - "note: llama.cpp does not use jinja parser, we only support commonly used templates\n", - value.c_str() - )); - } params.chat_template = value; } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); + add_opt(common_arg( + {"--chat-template-file"}, "JINJA_TEMPLATE_FILE", + string_format( + "set custom jinja chat template file (default: template taken from model's metadata)\n" + "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" + "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() + ), + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(params.chat_template)); + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), diff --git a/common/chat-template.hpp b/common/chat-template.hpp new file mode 100644 index 000000000..42ee0b615 --- /dev/null +++ b/common/chat-template.hpp @@ -0,0 +1,268 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include "minja.hpp" +#include +#include +#include + +using json = nlohmann::ordered_json; + +namespace minja { + +class chat_template { + public: + + private: + bool supports_tools_ = true; + // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. + // Most other templates (and OpenAI's API) expect the arguments object to be stringified. + bool requires_object_arguments_ = false; + bool requires_typed_content_ = false; + bool supports_system_role_ = true; + bool supports_parallel_tool_calls_ = false; + std::string source_; + std::string bos_token_; + std::string eos_token_; + std::shared_ptr template_root_; + + std::string try_raw_render( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const + { + try { + auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false); + // fprintf(stderr, "Prompt: %s\n", prompt.c_str()); + return prompt; + } catch (const std::exception & e) { + // fprintf(stderr, "Error: %s\n", e.what()); + return ""; + } + } + + public: + chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) + : source_(source), bos_token_(bos_token), eos_token_(eos_token) + { + template_root_ = minja::Parser::parse(source_, { + /* .trim_blocks = */ true, + /* .lstrip_blocks = */ true, + /* .keep_trailing_newline = */ false, + }); + supports_tools_ = source.find("tools") != std::string::npos; + + auto renders_string_arguments = + try_raw_render({ + { + {"role", "user"}, + {"content", "Hey"} + }, + { + {"role", "assistant"}, + {"tool_calls", json::array({ + { + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"arguments", "{\"code\": \"print('Hello, World!')\"}"}, + {"name", "ipython"}, + }}, + }, + })}, + } + }, {}, false).find("{\"code\": \"print") != std::string::npos; + if (!renders_string_arguments) { + auto renders_object_arguments = + try_raw_render({ + { + {"role", "user"}, + {"content", "Hey"} + }, + { + {"role", "assistant"}, + {"tool_calls", json::array({ + { + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"arguments", { + {"code", "print('Hello, World!')"}, + }}, + {"name", "ipython"}, + }}, + }, + })}, + } + }, {}, false).find("{\"code\": \"print") != std::string::npos; + requires_object_arguments_ = renders_object_arguments; + } + supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos; + + supports_system_role_ = try_raw_render({ + {{"role", "system"}, {"content", ""}}, + {{"role", "user"}, {"content", "Hey"}} + }, {}, false).find("") != std::string::npos; + + requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos + && try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos; + } + + const std::string & source() const { return source_; } + const std::string & bos_token() const { return bos_token_; } + const std::string & eos_token() const { return eos_token_; } + bool supports_tools() const { return supports_tools_; } + bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; } + + std::string apply( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), + bool adjust_inputs = true) const + { + json actual_messages; + + // First, "fix" messages so they have a chance to be rendered correctly by the template + + if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) { + actual_messages = json::array(); + + auto add_message = [&](const json & msg) { + if (requires_typed_content_ && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { + actual_messages.push_back({ + {"role", msg.at("role")}, + {"content", {{ + {"type", "text"}, + {"text", msg.at("content")}, + }}}, + }); + } else { + actual_messages.push_back(msg); + } + }; + + std::string pending_system; + auto flush_sys = [&]() { + if (!pending_system.empty()) { + add_message({ + {"role", "user"}, + {"content", pending_system}, + }); + pending_system.clear(); + } + }; + for (const auto & message_ : messages) { + auto message = message_; + if (!message.contains("role") || !message.contains("content")) { + throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); + } + std::string role = message.at("role"); + + if (message.contains("tool_calls")) { + if (requires_object_arguments_ || !supports_tools_) { + for (auto & tool_call : message.at("tool_calls")) { + if (tool_call["type"] == "function") { + auto & function = tool_call.at("function"); + std::string arguments = function.at("arguments"); + function["arguments"] = json::parse(arguments); + } + } + } + if (!supports_tools_) { + auto content = message.at("content"); + auto tool_calls = json::array(); + for (const auto & tool_call : message.at("tool_calls")) { + if (tool_call.at("type") != "function") { + continue; + } + const auto & function = tool_call.at("function"); + auto tc = json { + {"name", function.at("name")}, + {"arguments", function.at("arguments")}, + }; + if (tool_call.contains("id")) { + tc["id"] = tool_call["id"]; + } + tool_calls.push_back(tc); + } + auto obj = json { + {"tool_calls", tool_calls}, + }; + if (!content.is_null() && content != "") { + obj["content"] = content; + } + message["content"] = obj.dump(2); + message.erase("tool_calls"); + } + } + if (!supports_tools_ && role == "tool") { + message["role"] = "user"; + auto obj = json { + {"tool_response", { + {"tool", message.at("name")}, + {"content", message.at("content")}, + }}, + }; + if (message.contains("tool_call_id")) { + obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); + } + message["content"] = obj.dump(2); + message.erase("name"); + } + + if (!message["content"].is_null() && !supports_system_role_) { + std::string content = message.at("content"); + if (role == "system") { + if (!pending_system.empty()) pending_system += "\n"; + pending_system += content; + continue; + } else { + if (role == "user") { + if (!pending_system.empty()) { + message["content"] = pending_system + (content.empty() ? "" : "\n" + content); + pending_system.clear(); + } + } else { + flush_sys(); + } + } + } + add_message(message); + } + flush_sys(); + } else { + actual_messages = messages; + } + + auto context = minja::Context::make(json({ + {"messages", actual_messages}, + {"add_generation_prompt", add_generation_prompt}, + {"bos_token", bos_token_}, + {"eos_token", eos_token_}, + })); + + if (!tools.is_null()) { + auto tools_val = minja::Value(tools); + context->set("tools", tools_val); + } + if (!extra_context.is_null()) { + for (auto & kv : extra_context.items()) { + minja::Value val(kv.value()); + context->set(kv.key(), val); + } + } + + return template_root_->render(context); + } +}; + +} // namespace minja diff --git a/common/common.cpp b/common/common.cpp index 451826d5d..6dea8e3d2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -12,6 +12,7 @@ #include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" +#include "chat-template.hpp" #include #include @@ -483,6 +484,48 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +std::string string_join(const std::vector & values, const std::string & separator) { + std::ostringstream result; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + result << separator; + } + result << values[i]; + } + return result.str(); +} + +std::vector string_split(const std::string & str, const std::string & delimiter) { + std::vector parts; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + parts.push_back(str.substr(start, end - start)); + start = end + delimiter.length(); + end = str.find(delimiter, start); + } + + parts.push_back(str.substr(start)); + + return parts; +} + +std::string string_repeat(const std::string & str, size_t n) { + if (n == 0) { + return ""; + } + + std::string result; + result.reserve(str.length() * n); + + for (size_t i = 0; i < n; ++i) { + result += str; + } + + return result; +} + std::string string_from(bool value) { return value ? "true" : "false"; } @@ -1728,67 +1771,75 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto // Chat template utils // -std::string common_get_builtin_chat_template(const struct llama_model * model) { - const char * ptr_tmpl = llama_model_chat_template(model); - return ptr_tmpl == nullptr ? "" : ptr_tmpl; -} - -bool common_chat_verify_template(const std::string & tmpl) { +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { + if (use_jinja) { + try { + auto chat_template = minja::chat_template(tmpl, "", ""); + chat_template.apply({{ + {"role", "user"}, + {"content", "test"}, + }}, json(), true); + return true; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); + return false; + } + } llama_chat_message chat[] = {{"user", "test"}}; const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } -std::string common_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, +std::string common_chat_apply_template( + const common_chat_template & tmpl, const std::vector & msgs, - bool add_ass) { + bool add_ass, + bool use_jinja) { + if (use_jinja) { + auto messages = json::array(); + for (const auto & msg : msgs) { + messages.push_back({{"role", msg.role}, {"content", msg.content}}); + } + return tmpl.apply(messages, /* tools= */ json(), add_ass); + } + int alloc_size = 0; - bool fallback = false; // indicate if we must fallback to default chatml std::vector chat; for (const auto & msg : msgs) { chat.push_back({msg.role.c_str(), msg.content.c_str()}); alloc_size += (msg.role.size() + msg.content.size()) * 1.25; } - const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str(); std::vector buf(alloc_size); // run the first time to get the total output length - int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); // error: chat template is not supported if (res < 0) { - if (ptr_tmpl != nullptr) { - // if the custom "tmpl" is not supported, we throw an error - // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); - } - - // If the built-in template is not supported, we default to chatml - res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - fallback = true; + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); } // if it turns out that our buffer is too small, we resize it if ((size_t) res > buf.size()) { buf.resize(res); - res = llama_chat_apply_template( - fallback ? "chatml" : ptr_tmpl, - chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); } std::string formatted_chat(buf.data(), res); return formatted_chat; } -std::string common_chat_format_single(const struct llama_model * model, - const std::string & tmpl, +std::string common_chat_format_single( + const common_chat_template & tmpl, const std::vector & past_msg, const common_chat_msg & new_msg, - bool add_ass) { + bool add_ass, + bool use_jinja) { std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false); + auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { @@ -1796,21 +1847,74 @@ std::string common_chat_format_single(const struct llama_model * model, }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); } -std::string common_chat_format_example(const struct llama_model * model, - const std::string & tmpl) { +std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) { std::vector msgs = { {"system", "You are a helpful assistant"}, {"user", "Hello"}, {"assistant", "Hi there"}, {"user", "How are you?"}, }; - return common_chat_apply_template(model, tmpl, msgs, true); + return common_chat_apply_template(tmpl, msgs, true, use_jinja); +} + +common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) +{ + auto vocab = llama_model_get_vocab(model); + std::string default_template_src = chat_template_override; + std::string template_tool_use_src = chat_template_override; + bool has_explicit_template = !chat_template_override.empty(); + if (chat_template_override.empty()) { + auto str = llama_model_chat_template(model, /* name */ nullptr); + if (str) { + default_template_src = str; + has_explicit_template = true; + } + str = llama_model_chat_template(model, /* name */ "tool_use"); + if (str) { + template_tool_use_src = str; + has_explicit_template = true; + } + } + if (default_template_src.empty() || default_template_src == "chatml") { + if (!template_tool_use_src.empty()) { + default_template_src = template_tool_use_src; + } else { + default_template_src = R"( + {%- for message in messages -%} + {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} + {%- endfor -%} + {%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} + {%- endif -%} + )"; + } + } + const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { + if (token == LLAMA_TOKEN_NULL) { + if (default_template_src.find(jinja_variable_name) != std::string::npos + || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { + LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name); + } + return std::string(); + } else { + return common_token_to_piece(vocab, token, true); + } + }; + auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); + auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); + return { + has_explicit_template, + std::make_unique(default_template_src, token_bos, token_eos), + template_tool_use_src.empty() + ? nullptr + : std::make_unique(template_tool_use_src, token_bos, token_eos) + }; } // diff --git a/common/common.h b/common/common.h index 3bcc637cc..571260372 100644 --- a/common/common.h +++ b/common/common.h @@ -175,7 +175,11 @@ struct common_params_speculative { struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - std::string model = ""; // draft model for speculative decoding // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT + + std::string model = ""; // draft model for speculative decoding // NOLINT + std::string model_url = ""; // model url to download // NOLINT }; struct common_params_vocoder { @@ -330,6 +334,7 @@ struct common_params { std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT std::string chat_template = ""; // NOLINT + bool use_jinja = false; // NOLINT bool enable_chat_template = true; std::vector api_keys; @@ -424,6 +429,10 @@ std::string string_format(const char * fmt, ...); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); +std::string string_join(const std::vector & values, const std::string & separator); +std::vector string_split(const std::string & str, const std::string & delimiter); +std::string string_repeat(const std::string & str, size_t n); + void string_replace_all(std::string & s, const std::string & search, const std::string & replace); template @@ -508,12 +517,14 @@ struct llama_model * common_load_model_from_url( const std::string & local_path, const std::string & hf_token, const struct llama_model_params & params); + struct llama_model * common_load_model_from_hf( const std::string & repo, const std::string & remote_path, const std::string & local_path, const std::string & hf_token, const struct llama_model_params & params); + std::pair common_get_hf_file( const std::string & hf_repo_with_tag, const std::string & hf_token); @@ -597,30 +608,43 @@ struct common_chat_msg { std::string content; }; -// Get the built-in chat template for the model. Return empty string if not present. -std::string common_get_builtin_chat_template(const struct llama_model * model); - // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid -bool common_chat_verify_template(const std::string & tmpl); +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); + +namespace minja { + class chat_template; +} + +typedef minja::chat_template common_chat_template; + +struct common_chat_templates { + bool has_explicit_template; // Model had builtin template or template overridde was specified. + std::unique_ptr template_default; // always set (defaults to chatml) + std::unique_ptr template_tool_use; +}; // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error -std::string common_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, +std::string common_chat_apply_template( + const common_chat_template & tmpl, const std::vector & chat, - bool add_ass); + bool add_ass, + bool use_jinja); // Format single message, while taking into account the position of that message in chat history -std::string common_chat_format_single(const struct llama_model * model, - const std::string & tmpl, +std::string common_chat_format_single( + const common_chat_template & tmpl, const std::vector & past_msg, const common_chat_msg & new_msg, - bool add_ass); + bool add_ass, + bool use_jinja); // Returns an example of formatted chat -std::string common_chat_format_example(const struct llama_model * model, - const std::string & tmpl); +std::string common_chat_format_example( + const common_chat_template & tmpl, bool use_jinja); + +common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); // // KV cache utils diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index dadc18c8b..4d426b6bd 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -1,4 +1,6 @@ #include "json-schema-to-grammar.h" +#include "common.h" + #include #include #include @@ -11,11 +13,6 @@ using json = nlohmann::ordered_json; -template -static std::string join(Iterator begin, Iterator end, const std::string & separator); - -static std::string repeat(const std::string & str, size_t n); - static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { auto has_max = max_items != std::numeric_limits::max(); @@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & if (sub_len > 0) { auto from_sub = from.substr(i + 1); auto to_sub = to.substr(i + 1); - auto sub_zeros = repeat("0", sub_len); - auto sub_nines = repeat("9", sub_len); + auto sub_zeros = string_repeat("0", sub_len); + auto sub_nines = string_repeat("9", sub_len); auto to_reached = false; out << "("; @@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & auto max_digits = max_s.length(); for (auto digits = min_digits; digits < max_digits; digits++) { - uniform_range(min_s, repeat("9", digits)); - min_s = "1" + repeat("0", digits); + uniform_range(min_s, string_repeat("9", digits)); + min_s = "1" + string_repeat("0", digits); out << " | "; } uniform_range(min_s, max_s); @@ -318,49 +315,6 @@ std::unordered_map GRAMMAR_LITERAL_ESCAPES = { std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; -template -std::string join(Iterator begin, Iterator end, const std::string & separator) { - std::ostringstream result; - if (begin != end) { - result << *begin; - for (Iterator it = begin + 1; it != end; ++it) { - result << separator << *it; - } - } - return result.str(); -} - -static std::vector split(const std::string & str, const std::string & delimiter) { - std::vector tokens; - size_t start = 0; - size_t end = str.find(delimiter); - - while (end != std::string::npos) { - tokens.push_back(str.substr(start, end - start)); - start = end + delimiter.length(); - end = str.find(delimiter, start); - } - - tokens.push_back(str.substr(start)); - - return tokens; -} - -static std::string repeat(const std::string & str, size_t n) { - if (n == 0) { - return ""; - } - - std::string result; - result.reserve(str.length() * n); - - for (size_t i = 0; i < n; ++i) { - result += str; - } - - return result; -} - static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { std::smatch match; std::string result; @@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) { class SchemaConverter { private: + friend std::string build_grammar(const std::function & cb); std::function _fetch_json; bool _dotall; std::map _rules; @@ -418,7 +373,7 @@ private: for (size_t i = 0; i < alt_schemas.size(); i++) { rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); } - return join(rules.begin(), rules.end(), " | "); + return string_join(rules, " | "); } std::string _visit_pattern(const std::string & pattern, const std::string & name) { @@ -481,7 +436,7 @@ private: for (const auto & item : ret) { results.push_back(to_rule(item)); } - return std::make_pair(join(results.begin(), results.end(), " "), false); + return std::make_pair(string_join(results, " "), false); }; while (i < length) { @@ -539,7 +494,7 @@ private: } curly_brackets += '}'; i++; - auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); + auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); int min_times = 0; int max_times = std::numeric_limits::max(); try { @@ -854,7 +809,7 @@ public: return; } std::string pointer = ref.substr(ref.find('#') + 1); - std::vector tokens = split(pointer, "/"); + std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; if (target.is_null() || !target.contains(sel)) { @@ -905,7 +860,7 @@ public: for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -1019,10 +974,10 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n")); + throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { - fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str()); + fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); } } @@ -1036,10 +991,27 @@ public: }; std::string json_schema_to_grammar(const json & schema) { - SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); - auto copy = schema; - converter.resolve_refs(copy, "input"); - converter.visit(copy, ""); + return build_grammar([&](const llama_grammar_builder & callbacks) { + auto copy = schema; + callbacks.resolve_refs(copy); + callbacks.add_schema("", copy); + }); +} + +std::string build_grammar(const std::function & cb) { + SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false); + llama_grammar_builder builder { + /* .add_rule = */ [&](const std::string & name, const std::string & rule) { + return converter._add_rule(name, rule); + }, + /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) { + return converter.visit(schema, name == "root" ? "" : name); + }, + /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) { + converter.resolve_refs(schema, ""); + } + }; + cb(builder); converter.check_errors(); return converter.format_grammar(); } diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 41623b346..4f43ab3a5 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -5,4 +5,12 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); +std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); + +struct llama_grammar_builder { + std::function add_rule; + std::function add_schema; + std::function resolve_refs; +}; + +std::string build_grammar(const std::function & cb); diff --git a/common/minja.hpp b/common/minja.hpp new file mode 100644 index 000000000..80bdd4b41 --- /dev/null +++ b/common/minja.hpp @@ -0,0 +1,2812 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +namespace minja { + +class Context; + +struct Options { + bool trim_blocks; // removes the first newline after a block + bool lstrip_blocks; // removes leading whitespace on the line of the block + bool keep_trailing_newline; // don't remove last newline +}; + +struct ArgumentsValue; + +inline std::string normalize_newlines(const std::string & s) { +#ifdef _WIN32 + static const std::regex nl_regex("\r\n"); + return std::regex_replace(s, nl_regex, "\n"); +#else + return s; +#endif +} + +/* Values that behave roughly like in Python. */ +class Value : public std::enable_shared_from_this { +public: + using CallableType = std::function &, ArgumentsValue &)>; + using FilterType = std::function &, ArgumentsValue &)>; + +private: + using ObjectType = nlohmann::ordered_map; // Only contains primitive keys + using ArrayType = std::vector; + + std::shared_ptr array_; + std::shared_ptr object_; + std::shared_ptr callable_; + json primitive_; + + Value(const std::shared_ptr & array) : array_(array) {} + Value(const std::shared_ptr & object) : object_(object) {} + Value(const std::shared_ptr & callable) : object_(std::make_shared()), callable_(callable) {} + + /* Python-style string repr */ + static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') { + if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump()); + auto s = primitive.dump(); + if (string_quote == '"' || s.find('\'') != std::string::npos) { + out << s; + return; + } + // Reuse json dump, just changing string quotes + out << string_quote; + for (size_t i = 1, n = s.size() - 1; i < n; ++i) { + if (s[i] == '\\' && s[i + 1] == '"') { + out << '"'; + i++; + } else if (s[i] == string_quote) { + out << '\\' << string_quote; + } else { + out << s[i]; + } + } + out << string_quote; + } + void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const { + auto print_indent = [&](int level) { + if (indent > 0) { + out << "\n"; + for (int i = 0, n = level * indent; i < n; ++i) out << ' '; + } + }; + auto print_sub_sep = [&]() { + out << ','; + if (indent < 0) out << ' '; + else print_indent(level + 1); + }; + + auto string_quote = to_json ? '"' : '\''; + + if (is_null()) out << "null"; + else if (array_) { + out << "["; + print_indent(level + 1); + for (size_t i = 0; i < array_->size(); ++i) { + if (i) print_sub_sep(); + (*array_)[i].dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "]"; + } else if (object_) { + out << "{"; + print_indent(level + 1); + for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) { + if (it != begin) print_sub_sep(); + if (it->first.is_string()) { + dump_string(it->first, out, string_quote); + } else { + out << string_quote << it->first.dump() << string_quote; + } + out << ": "; + it->second.dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "}"; + } else if (callable_) { + throw std::runtime_error("Cannot dump callable to JSON"); + } else if (is_boolean() && !to_json) { + out << (this->to_bool() ? "True" : "False"); + } else if (is_string() && !to_json) { + dump_string(primitive_, out, string_quote); + } else { + out << primitive_.dump(); + } + } + +public: + Value() {} + Value(const bool& v) : primitive_(v) {} + Value(const int64_t & v) : primitive_(v) {} + Value(const double& v) : primitive_(v) {} + Value(const std::nullptr_t &) {} + Value(const std::string & v) : primitive_(v) {} + Value(const char * v) : primitive_(std::string(v)) {} + + Value(const json & v) { + if (v.is_object()) { + auto object = std::make_shared(); + for (auto it = v.begin(); it != v.end(); ++it) { + (*object)[it.key()] = it.value(); + } + object_ = std::move(object); + } else if (v.is_array()) { + auto array = std::make_shared(); + for (const auto& item : v) { + array->push_back(Value(item)); + } + array_ = array; + } else { + primitive_ = v; + } + } + + std::vector keys() { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + std::vector res; + for (const auto& item : *object_) { + res.push_back(item.first); + } + return res; + } + + size_t size() const { + if (is_object()) return object_->size(); + if (is_array()) return array_->size(); + if (is_string()) return primitive_.get().length(); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + static Value array(const std::vector values = {}) { + auto array = std::make_shared(); + for (const auto& item : values) { + array->push_back(item); + } + return Value(array); + } + static Value object(const std::shared_ptr object = std::make_shared()) { + return Value(object); + } + static Value callable(const CallableType & callable) { + return Value(std::make_shared(callable)); + } + + void insert(size_t index, const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->insert(array_->begin() + index, v); + } + void push_back(const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->push_back(v); + } + Value pop(const Value& index) { + if (is_array()) { + if (array_->empty()) + throw std::runtime_error("pop from empty list"); + if (index.is_null()) { + auto ret = array_->back(); + array_->pop_back(); + return ret; + } else if (!index.is_number_integer()) { + throw std::runtime_error("pop index must be an integer: " + index.dump()); + } else { + auto i = index.get(); + if (i < 0 || i >= static_cast(array_->size())) + throw std::runtime_error("pop index out of range: " + index.dump()); + auto it = array_->begin() + (i < 0 ? array_->size() + i : i); + auto ret = *it; + array_->erase(it); + return ret; + } + } else if (is_object()) { + if (!index.is_hashable()) + throw std::runtime_error("Unashable type: " + index.dump()); + auto it = object_->find(index.primitive_); + if (it == object_->end()) + throw std::runtime_error("Key not found: " + index.dump()); + auto ret = it->second; + object_->erase(it); + return ret; + } else { + throw std::runtime_error("Value is not an array or object: " + dump()); + } + } + Value get(const Value& key) { + if (array_) { + if (!key.is_number_integer()) { + return Value(); + } + auto index = key.get(); + return array_->at(index < 0 ? array_->size() + index : index); + } else if (object_) { + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + auto it = object_->find(key.primitive_); + if (it == object_->end()) return Value(); + return it->second; + } + return Value(); + } + void set(const Value& key, const Value& value) { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + (*object_)[key.primitive_] = value; + } + Value call(const std::shared_ptr & context, ArgumentsValue & args) const { + if (!callable_) throw std::runtime_error("Value is not callable: " + dump()); + return (*callable_)(context, args); + } + + bool is_object() const { return !!object_; } + bool is_array() const { return !!array_; } + bool is_callable() const { return !!callable_; } + bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; } + bool is_boolean() const { return primitive_.is_boolean(); } + bool is_number_integer() const { return primitive_.is_number_integer(); } + bool is_number_float() const { return primitive_.is_number_float(); } + bool is_number() const { return primitive_.is_number(); } + bool is_string() const { return primitive_.is_string(); } + bool is_iterable() const { return is_array() || is_object() || is_string(); } + + bool is_primitive() const { return !array_ && !object_ && !callable_; } + bool is_hashable() const { return is_primitive(); } + + bool empty() const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_string()) return primitive_.empty(); + if (is_array()) return array_->empty(); + if (is_object()) return object_->empty(); + return false; + } + + void for_each(const std::function & callback) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (auto& item : *array_) { + callback(item); + } + } else if (object_) { + for (auto & item : *object_) { + Value key(item.first); + callback(key); + } + } else if (is_string()) { + for (char c : primitive_.get()) { + auto val = Value(std::string(1, c)); + callback(val); + } + } else { + throw std::runtime_error("Value is not iterable: " + dump()); + } + } + + bool to_bool() const { + if (is_null()) return false; + if (is_boolean()) return get(); + if (is_number()) return get() != 0; + if (is_string()) return !get().empty(); + if (is_array()) return !empty(); + return true; + } + + int64_t to_int() const { + if (is_null()) return 0; + if (is_boolean()) return get() ? 1 : 0; + if (is_number()) return static_cast(get()); + if (is_string()) { + try { + return std::stol(get()); + } catch (const std::exception &) { + return 0; + } + } + return 0; + } + + bool operator<(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() < other.get(); + if (is_string() && other.is_string()) return get() < other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump()); + } + bool operator>=(const Value & other) const { return !(*this < other); } + + bool operator>(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() > other.get(); + if (is_string() && other.is_string()) return get() > other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump()); + } + bool operator<=(const Value & other) const { return !(*this > other); } + + bool operator==(const Value & other) const { + if (callable_ || other.callable_) { + if (callable_.get() != other.callable_.get()) return false; + } + if (array_) { + if (!other.array_) return false; + if (array_->size() != other.array_->size()) return false; + for (size_t i = 0; i < array_->size(); ++i) { + if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false; + } + return true; + } else if (object_) { + if (!other.object_) return false; + if (object_->size() != other.object_->size()) return false; + for (const auto& item : *object_) { + if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false; + } + return true; + } else { + return primitive_ == other.primitive_; + } + } + bool operator!=(const Value & other) const { return !(*this == other); } + + bool contains(const char * key) const { return contains(std::string(key)); } + bool contains(const std::string & key) const { + if (array_) { + return false; + } else if (object_) { + return object_->find(key) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + bool contains(const Value & value) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (const auto& item : *array_) { + if (item.to_bool() && item == value) return true; + } + return false; + } else if (object_) { + if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump()); + return object_->find(value.primitive_) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + void erase(size_t index) { + if (!array_) throw std::runtime_error("Value is not an array: " + dump()); + array_->erase(array_->begin() + index); + } + void erase(const std::string & key) { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + object_->erase(key); + } + const Value& at(const Value & index) const { + return const_cast(this)->at(index); + } + Value& at(const Value & index) { + if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (is_array()) return array_->at(index.get()); + if (is_object()) return object_->at(index.primitive_); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + const Value& at(size_t index) const { + return const_cast(this)->at(index); + } + Value& at(size_t index) { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_array()) return array_->at(index); + if (is_object()) return object_->at(index); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + template + T get(const std::string & key, T default_value) const { + if (!contains(key)) return default_value; + return at(key).get(); + } + + template + T get() const { + if (is_primitive()) return primitive_.get(); + throw std::runtime_error("get not defined for this value type: " + dump()); + } + + std::string dump(int indent=-1, bool to_json=false) const { + std::ostringstream out; + dump(out, indent, 0, to_json); + return out.str(); + } + + Value operator-() const { + if (is_number_integer()) + return -get(); + else + return -get(); + } + std::string to_str() const { + if (is_string()) return get(); + if (is_number_integer()) return std::to_string(get()); + if (is_number_float()) return std::to_string(get()); + if (is_boolean()) return get() ? "True" : "False"; + if (is_null()) return "None"; + return dump(); + } + Value operator+(const Value& rhs) const { + if (is_string() || rhs.is_string()) { + return to_str() + rhs.to_str(); + } else if (is_number_integer() && rhs.is_number_integer()) { + return get() + rhs.get(); + } else if (is_array() && rhs.is_array()) { + auto res = Value::array(); + for (const auto& item : *array_) res.push_back(item); + for (const auto& item : *rhs.array_) res.push_back(item); + return res; + } else { + return get() + rhs.get(); + } + } + Value operator-(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() - rhs.get(); + else + return get() - rhs.get(); + } + Value operator*(const Value& rhs) const { + if (is_string() && rhs.is_number_integer()) { + std::ostringstream out; + for (int64_t i = 0, n = rhs.get(); i < n; ++i) { + out << to_str(); + } + return out.str(); + } + else if (is_number_integer() && rhs.is_number_integer()) + return get() * rhs.get(); + else + return get() * rhs.get(); + } + Value operator/(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() / rhs.get(); + else + return get() / rhs.get(); + } + Value operator%(const Value& rhs) const { + return get() % rhs.get(); + } +}; + +struct ArgumentsValue { + std::vector args; + std::vector> kwargs; + + bool has_named(const std::string & name) { + for (const auto & p : kwargs) { + if (p.first == name) return true; + } + return false; + } + + Value get_named(const std::string & name) { + for (const auto & [key, value] : kwargs) { + if (key == name) return value; + } + return Value(); + } + + bool empty() { + return args.empty() && kwargs.empty(); + } + + void expectArgs(const std::string & method_name, const std::pair & pos_count, const std::pair & kw_count) { + if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { + std::ostringstream out; + out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments"; + throw std::runtime_error(out.str()); + } + } +}; + +template <> +inline json Value::get() const { + if (is_primitive()) return primitive_; + if (is_null()) return json(); + if (array_) { + std::vector res; + for (const auto& item : *array_) { + res.push_back(item.get()); + } + return res; + } + if (object_) { + json res = json::object(); + for (const auto& [key, value] : *object_) { + if (key.is_string()) { + res[key.get()] = value.get(); + } else if (key.is_primitive()) { + res[key.dump()] = value.get(); + } else { + throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump()); + } + } + if (is_callable()) { + res["__callable__"] = true; + } + return res; + } + throw std::runtime_error("get not defined for this value type: " + dump()); +} + +} // namespace minja + +namespace std { + template <> + struct hash { + size_t operator()(const minja::Value & v) const { + if (!v.is_hashable()) + throw std::runtime_error("Unsupported type for hashing: " + v.dump()); + return std::hash()(v.get()); + } + }; +} // namespace std + +namespace minja { + +static std::string error_location_suffix(const std::string & source, size_t pos) { + auto get_line = [&](size_t line) { + auto start = source.begin(); + for (size_t i = 1; i < line; ++i) { + start = std::find(start, source.end(), '\n') + 1; + } + auto end = std::find(start, source.end(), '\n'); + return std::string(start, end); + }; + auto start = source.begin(); + auto end = source.end(); + auto it = start + pos; + auto line = std::count(start, it, '\n') + 1; + auto max_line = std::count(start, end, '\n') + 1; + auto col = pos - std::string(start, it).rfind('\n'); + std::ostringstream out; + out << " at row " << line << ", column " << col << ":\n"; + if (line > 1) out << get_line(line - 1) << "\n"; + out << get_line(line) << "\n"; + out << std::string(col - 1, ' ') << "^\n"; + if (line < max_line) out << get_line(line + 1) << "\n"; + + return out.str(); +} + +class Context : public std::enable_shared_from_this { + protected: + Value values_; + std::shared_ptr parent_; + public: + Context(Value && values, const std::shared_ptr & parent = nullptr) : values_(std::move(values)), parent_(parent) { + if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump()); + } + virtual ~Context() {} + + static std::shared_ptr builtins(); + static std::shared_ptr make(Value && values, const std::shared_ptr & parent = builtins()); + + std::vector keys() { + return values_.keys(); + } + virtual Value get(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->get(key); + return Value(); + } + virtual Value & at(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->at(key); + throw std::runtime_error("Undefined variable: " + key.dump()); + } + virtual bool contains(const Value & key) { + if (values_.contains(key)) return true; + if (parent_) return parent_->contains(key); + return false; + } + virtual void set(const Value & key, Value & value) { + values_.set(key, value); + } +}; + +struct Location { + std::shared_ptr source; + size_t pos; +}; + +class Expression { +protected: + virtual Value do_evaluate(const std::shared_ptr & context) const = 0; +public: + using Parameters = std::vector>>; + + Location location; + + Expression(const Location & location) : location(location) {} + virtual ~Expression() = default; + + Value evaluate(const std::shared_ptr & context) const { + try { + return do_evaluate(context); + } catch (const std::exception & e) { + std::ostringstream out; + out << e.what(); + if (location.source) out << error_location_suffix(*location.source, location.pos); + throw std::runtime_error(out.str()); + } + } +}; + +class VariableExpr : public Expression { + std::string name; +public: + VariableExpr(const Location & location, const std::string& n) + : Expression(location), name(n) {} + std::string get_name() const { return name; } + Value do_evaluate(const std::shared_ptr & context) const override { + if (!context->contains(name)) { + return Value(); + } + return context->at(name); + } +}; + +static void destructuring_assign(const std::vector & var_names, const std::shared_ptr & context, Value& item) { + if (var_names.size() == 1) { + Value name(var_names[0]); + context->set(name, item); + } else { + if (!item.is_array() || item.size() != var_names.size()) { + throw std::runtime_error("Mismatched number of variables and items in destructuring assignment"); + } + for (size_t i = 0; i < var_names.size(); ++i) { + context->set(var_names[i], item.at(i)); + } + } +} + +enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; + +class TemplateToken { +public: + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter }; + + static std::string typeToString(Type t) { + switch (t) { + case Type::Text: return "text"; + case Type::Expression: return "expression"; + case Type::If: return "if"; + case Type::Else: return "else"; + case Type::Elif: return "elif"; + case Type::EndIf: return "endif"; + case Type::For: return "for"; + case Type::EndFor: return "endfor"; + case Type::Set: return "set"; + case Type::EndSet: return "endset"; + case Type::Comment: return "comment"; + case Type::Macro: return "macro"; + case Type::EndMacro: return "endmacro"; + case Type::Filter: return "filter"; + case Type::EndFilter: return "endfilter"; + case Type::Generation: return "generation"; + case Type::EndGeneration: return "endgeneration"; + } + return "Unknown"; + } + + TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {} + virtual ~TemplateToken() = default; + + Type type; + Location location; + SpaceHandling pre_space = SpaceHandling::Keep; + SpaceHandling post_space = SpaceHandling::Keep; +}; + +struct TextTemplateToken : public TemplateToken { + std::string text; + TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} +}; + +struct ExpressionTemplateToken : public TemplateToken { + std::shared_ptr expr; + ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} +}; + +struct IfTemplateToken : public TemplateToken { + std::shared_ptr condition; + IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} +}; + +struct ElifTemplateToken : public TemplateToken { + std::shared_ptr condition; + ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} +}; + +struct ElseTemplateToken : public TemplateToken { + ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} +}; + +struct EndIfTemplateToken : public TemplateToken { + EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} +}; + +struct MacroTemplateToken : public TemplateToken { + std::shared_ptr name; + Expression::Parameters params; + MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) + : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} +}; + +struct EndMacroTemplateToken : public TemplateToken { + EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} +}; + +struct FilterTemplateToken : public TemplateToken { + std::shared_ptr filter; + FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) + : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {} +}; + +struct EndFilterTemplateToken : public TemplateToken { + EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {} +}; + +struct ForTemplateToken : public TemplateToken { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + bool recursive; + ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, + std::shared_ptr && c, bool r) + : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} +}; + +struct EndForTemplateToken : public TemplateToken { + EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} +}; + +struct GenerationTemplateToken : public TemplateToken { + GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {} +}; + +struct EndGenerationTemplateToken : public TemplateToken { + EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {} +}; + +struct SetTemplateToken : public TemplateToken { + std::string ns; + std::vector var_names; + std::shared_ptr value; + SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} +}; + +struct EndSetTemplateToken : public TemplateToken { + EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} +}; + +struct CommentTemplateToken : public TemplateToken { + std::string text; + CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} +}; + +class TemplateNode { + Location location_; +protected: + virtual void do_render(std::ostringstream & out, const std::shared_ptr & context) const = 0; + +public: + TemplateNode(const Location & location) : location_(location) {} + void render(std::ostringstream & out, const std::shared_ptr & context) const { + try { + do_render(out, context); + } catch (const std::exception & e) { + std::ostringstream err; + err << e.what(); + if (location_.source) err << error_location_suffix(*location_.source, location_.pos); + throw std::runtime_error(err.str()); + } + } + const Location & location() const { return location_; } + virtual ~TemplateNode() = default; + std::string render(const std::shared_ptr & context) const { + std::ostringstream out; + render(out, context); + return out.str(); + } +}; + +class SequenceNode : public TemplateNode { + std::vector> children; +public: + SequenceNode(const Location & location, std::vector> && c) + : TemplateNode(location), children(std::move(c)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& child : children) child->render(out, context); + } +}; + +class TextNode : public TemplateNode { + std::string text; +public: + TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} + void do_render(std::ostringstream & out, const std::shared_ptr &) const override { + out << text; + } +}; + +class ExpressionNode : public TemplateNode { + std::shared_ptr expr; +public: + ExpressionNode(const Location & location, std::shared_ptr && e) : TemplateNode(location), expr(std::move(e)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("ExpressionNode.expr is null"); + auto result = expr->evaluate(context); + if (result.is_string()) { + out << result.get(); + } else if (result.is_boolean()) { + out << (result.get() ? "True" : "False"); + } else if (!result.is_null()) { + out << result.dump(); + } + } +}; + +class IfNode : public TemplateNode { + std::vector, std::shared_ptr>> cascade; +public: + IfNode(const Location & location, std::vector, std::shared_ptr>> && c) + : TemplateNode(location), cascade(std::move(c)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& branch : cascade) { + auto enter_branch = true; + if (branch.first) { + enter_branch = branch.first->evaluate(context).to_bool(); + } + if (enter_branch) { + if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null"); + branch.second->render(out, context); + return; + } + } + } +}; + +class ForNode : public TemplateNode { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + std::shared_ptr body; + bool recursive; + std::shared_ptr else_body; +public: + ForNode(const Location & location, std::vector && var_names, std::shared_ptr && iterable, + std::shared_ptr && condition, std::shared_ptr && body, bool recursive, std::shared_ptr && else_body) + : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + // https://jinja.palletsprojects.com/en/3.0.x/templates/#for + if (!iterable) throw std::runtime_error("ForNode.iterable is null"); + if (!body) throw std::runtime_error("ForNode.body is null"); + + auto iterable_value = iterable->evaluate(context); + Value::CallableType loop_function; + + std::function visit = [&](Value& iter) { + auto filtered_items = Value::array(); + if (!iter.is_null()) { + if (!iterable_value.is_iterable()) { + throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump()); + } + iterable_value.for_each([&](Value & item) { + destructuring_assign(var_names, context, item); + if (!condition || condition->evaluate(context).to_bool()) { + filtered_items.push_back(item); + } + }); + } + if (filtered_items.empty()) { + if (else_body) { + else_body->render(out, context); + } + } else { + auto loop = recursive ? Value::callable(loop_function) : Value::object(); + loop.set("length", (int64_t) filtered_items.size()); + + size_t cycle_index = 0; + loop.set("cycle", Value::callable([&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.empty() || !args.kwargs.empty()) { + throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg"); + } + auto item = args.args[cycle_index]; + cycle_index = (cycle_index + 1) % args.args.size(); + return item; + })); + auto loop_context = Context::make(Value::object(), context); + loop_context->set("loop", loop); + for (size_t i = 0, n = filtered_items.size(); i < n; ++i) { + auto & item = filtered_items.at(i); + destructuring_assign(var_names, loop_context, item); + loop.set("index", (int64_t) i + 1); + loop.set("index0", (int64_t) i); + loop.set("revindex", (int64_t) (n - i)); + loop.set("revindex0", (int64_t) (n - i - 1)); + loop.set("length", (int64_t) n); + loop.set("first", i == 0); + loop.set("last", i == (n - 1)); + loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); + loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value()); + body->render(out, loop_context); + } + } + }; + + if (recursive) { + loop_function = [&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) { + throw std::runtime_error("loop() expects exactly 1 positional iterable argument"); + } + auto & items = args.args[0]; + visit(items); + return Value(); + }; + } + + visit(iterable_value); + } +}; + +class MacroNode : public TemplateNode { + std::shared_ptr name; + Expression::Parameters params; + std::shared_ptr body; + std::unordered_map named_param_positions; +public: + MacroNode(const Location & location, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) + : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { + for (size_t i = 0; i < params.size(); ++i) { + const auto & name = params[i].first; + if (!name.empty()) { + named_param_positions[name] = i; + } + } + } + void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { + if (!name) throw std::runtime_error("MacroNode.name is null"); + if (!body) throw std::runtime_error("MacroNode.body is null"); + auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { + auto call_context = macro_context; + std::vector param_set(params.size(), false); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); + param_set[i] = true; + auto & param_name = params[i].first; + call_context->set(param_name, arg); + } + for (auto & [arg_name, value] : args.kwargs) { + auto it = named_param_positions.find(arg_name); + if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); + + call_context->set(arg_name, value); + param_set[it->second] = true; + } + // Set default values for parameters that were not passed + for (size_t i = 0, n = params.size(); i < n; i++) { + if (!param_set[i] && params[i].second != nullptr) { + auto val = params[i].second->evaluate(context); + call_context->set(params[i].first, val); + } + } + return body->render(call_context); + }); + macro_context->set(name->get_name(), callable); + } +}; + +class FilterNode : public TemplateNode { + std::shared_ptr filter; + std::shared_ptr body; + +public: + FilterNode(const Location & location, std::shared_ptr && f, std::shared_ptr && b) + : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!filter) throw std::runtime_error("FilterNode.filter is null"); + if (!body) throw std::runtime_error("FilterNode.body is null"); + auto filter_value = filter->evaluate(context); + if (!filter_value.is_callable()) { + throw std::runtime_error("Filter must be a callable: " + filter_value.dump()); + } + std::string rendered_body = body->render(context); + + ArgumentsValue filter_args = {{Value(rendered_body)}, {}}; + auto result = filter_value.call(context, filter_args); + out << result.to_str(); + } +}; + +class SetNode : public TemplateNode { + std::string ns; + std::vector var_names; + std::shared_ptr value; +public: + SetNode(const Location & location, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!value) throw std::runtime_error("SetNode.value is null"); + if (!ns.empty()) { + if (var_names.size() != 1) { + throw std::runtime_error("Namespaced set only supports a single variable name"); + } + auto & name = var_names[0]; + auto ns_value = context->get(ns); + if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object"); + ns_value.set(name, this->value->evaluate(context)); + } else { + auto val = value->evaluate(context); + destructuring_assign(var_names, context, val); + } + } +}; + +class SetTemplateNode : public TemplateNode { + std::string name; + std::shared_ptr template_value; +public: + SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr && tv) + : TemplateNode(location), name(name), template_value(std::move(tv)) {} + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null"); + Value value { template_value->render(context) }; + context->set(name, value); + } +}; + +class IfExpr : public Expression { + std::shared_ptr condition; + std::shared_ptr then_expr; + std::shared_ptr else_expr; +public: + IfExpr(const Location & location, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) + : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!condition) throw std::runtime_error("IfExpr.condition is null"); + if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null"); + if (condition->evaluate(context).to_bool()) { + return then_expr->evaluate(context); + } + if (else_expr) { + return else_expr->evaluate(context); + } + return nullptr; + } +}; + +class LiteralExpr : public Expression { + Value value; +public: + LiteralExpr(const Location & location, const Value& v) + : Expression(location), value(v) {} + Value do_evaluate(const std::shared_ptr &) const override { return value; } +}; + +class ArrayExpr : public Expression { + std::vector> elements; +public: + ArrayExpr(const Location & location, std::vector> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::array(); + for (const auto& e : elements) { + if (!e) throw std::runtime_error("Array element is null"); + result.push_back(e->evaluate(context)); + } + return result; + } +}; + +class DictExpr : public Expression { + std::vector, std::shared_ptr>> elements; +public: + DictExpr(const Location & location, std::vector, std::shared_ptr>> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::object(); + for (const auto& [key, value] : elements) { + if (!key) throw std::runtime_error("Dict key is null"); + if (!value) throw std::runtime_error("Dict value is null"); + result.set(key->evaluate(context), value->evaluate(context)); + } + return result; + } +}; + +class SliceExpr : public Expression { +public: + std::shared_ptr start, end; + SliceExpr(const Location & location, std::shared_ptr && s, std::shared_ptr && e) + : Expression(location), start(std::move(s)), end(std::move(e)) {} + Value do_evaluate(const std::shared_ptr &) const override { + throw std::runtime_error("SliceExpr not implemented"); + } +}; + +class SubscriptExpr : public Expression { + std::shared_ptr base; + std::shared_ptr index; +public: + SubscriptExpr(const Location & location, std::shared_ptr && b, std::shared_ptr && i) + : Expression(location), base(std::move(b)), index(std::move(i)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!base) throw std::runtime_error("SubscriptExpr.base is null"); + if (!index) throw std::runtime_error("SubscriptExpr.index is null"); + auto target_value = base->evaluate(context); + if (auto slice = dynamic_cast(index.get())) { + auto start = slice->start ? slice->start->evaluate(context).get() : 0; + auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + if (target_value.is_string()) { + std::string s = target_value.get(); + if (start < 0) start = s.size() + start; + if (end < 0) end = s.size() + end; + return s.substr(start, end - start); + } else if (target_value.is_array()) { + if (start < 0) start = target_value.size() + start; + if (end < 0) end = target_value.size() + end; + auto result = Value::array(); + for (auto i = start; i < end; ++i) { + result.push_back(target_value.at(i)); + } + return result; + } else { + throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings"); + } + } else { + auto index_value = index->evaluate(context); + if (target_value.is_null()) { + if (auto t = dynamic_cast(base.get())) { + throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined")); + } + throw std::runtime_error("Trying to access property '" + index_value.dump() + "' on null!"); + } + return target_value.get(index_value); + } + } +}; + +class UnaryOpExpr : public Expression { +public: + enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; + std::shared_ptr expr; + Op op; + UnaryOpExpr(const Location & location, std::shared_ptr && e, Op o) + : Expression(location), expr(std::move(e)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null"); + auto e = expr->evaluate(context); + switch (op) { + case Op::Plus: return e; + case Op::Minus: return -e; + case Op::LogicalNot: return !e.to_bool(); + case Op::Expansion: + case Op::ExpansionDict: + throw std::runtime_error("Expansion operator is only supported in function calls and collections"); + + } + throw std::runtime_error("Unknown unary operator"); + } +}; + +class BinaryOpExpr : public Expression { +public: + enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; +private: + std::shared_ptr left; + std::shared_ptr right; + Op op; +public: + BinaryOpExpr(const Location & location, std::shared_ptr && l, std::shared_ptr && r, Op o) + : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!left) throw std::runtime_error("BinaryOpExpr.left is null"); + if (!right) throw std::runtime_error("BinaryOpExpr.right is null"); + auto l = left->evaluate(context); + + auto do_eval = [&](const Value & l) -> Value { + if (op == Op::Is || op == Op::IsNot) { + auto t = dynamic_cast(right.get()); + if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable"); + + auto eval = [&]() { + const auto & name = t->get_name(); + if (name == "none") return l.is_null(); + if (name == "boolean") return l.is_boolean(); + if (name == "integer") return l.is_number_integer(); + if (name == "float") return l.is_number_float(); + if (name == "number") return l.is_number(); + if (name == "string") return l.is_string(); + if (name == "mapping") return l.is_object(); + if (name == "iterable") return l.is_iterable(); + if (name == "sequence") return l.is_array(); + if (name == "defined") return !l.is_null(); + throw std::runtime_error("Unknown type for 'is' operator: " + name); + }; + auto value = eval(); + return Value(op == Op::Is ? value : !value); + } + + if (op == Op::And) { + if (!l.to_bool()) return Value(false); + return right->evaluate(context).to_bool(); + } else if (op == Op::Or) { + if (l.to_bool()) return l; + return right->evaluate(context); + } + + auto r = right->evaluate(context); + switch (op) { + case Op::StrConcat: return l.to_str() + r.to_str(); + case Op::Add: return l + r; + case Op::Sub: return l - r; + case Op::Mul: return l * r; + case Op::Div: return l / r; + case Op::MulMul: return std::pow(l.get(), r.get()); + case Op::DivDiv: return l.get() / r.get(); + case Op::Mod: return l.get() % r.get(); + case Op::Eq: return l == r; + case Op::Ne: return l != r; + case Op::Lt: return l < r; + case Op::Gt: return l > r; + case Op::Le: return l <= r; + case Op::Ge: return l >= r; + case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); + case Op::NotIn: return !(r.is_array() && r.contains(l)); + default: break; + } + throw std::runtime_error("Unknown binary operator"); + }; + + if (l.is_callable()) { + return Value::callable([l, do_eval](const std::shared_ptr & context, ArgumentsValue & args) { + auto ll = l.call(context, args); + return do_eval(ll); //args[0].second); + }); + } else { + return do_eval(l); + } + } +}; + +struct ArgumentsExpression { + std::vector> args; + std::vector>> kwargs; + + ArgumentsValue evaluate(const std::shared_ptr & context) const { + ArgumentsValue vargs; + for (const auto& arg : this->args) { + if (auto un_expr = std::dynamic_pointer_cast(arg)) { + if (un_expr->op == UnaryOpExpr::Op::Expansion) { + auto array = un_expr->expr->evaluate(context); + if (!array.is_array()) { + throw std::runtime_error("Expansion operator only supported on arrays"); + } + array.for_each([&](Value & value) { + vargs.args.push_back(value); + }); + continue; + } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) { + auto dict = un_expr->expr->evaluate(context); + if (!dict.is_object()) { + throw std::runtime_error("ExpansionDict operator only supported on objects"); + } + dict.for_each([&](const Value & key) { + vargs.kwargs.push_back({key.get(), dict.at(key)}); + }); + continue; + } + } + vargs.args.push_back(arg->evaluate(context)); + } + for (const auto& [name, value] : this->kwargs) { + vargs.kwargs.push_back({name, value->evaluate(context)}); + } + return vargs; + } +}; + +static std::string strip(const std::string & s) { + auto start = s.find_first_not_of(" \t\n\r"); + if (start == std::string::npos) return ""; + auto end = s.find_last_not_of(" \t\n\r"); + return s.substr(start, end - start + 1); +} + +static std::string html_escape(const std::string & s) { + std::string result; + result.reserve(s.size()); + for (const auto & c : s) { + switch (c) { + case '&': result += "&"; break; + case '<': result += "<"; break; + case '>': result += ">"; break; + case '"': result += """; break; + case '\'': result += "'"; break; + default: result += c; break; + } + } + return result; +} + +class MethodCallExpr : public Expression { + std::shared_ptr object; + std::shared_ptr method; + ArgumentsExpression args; +public: + MethodCallExpr(const Location & location, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) + : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) throw std::runtime_error("MethodCallExpr.object is null"); + if (!method) throw std::runtime_error("MethodCallExpr.method is null"); + auto obj = object->evaluate(context); + auto vargs = args.evaluate(context); + if (obj.is_null()) { + throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null"); + } + if (obj.is_array()) { + if (method->get_name() == "append") { + vargs.expectArgs("append method", {1, 1}, {0, 0}); + obj.push_back(vargs.args[0]); + return Value(); + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {0, 1}, {0, 0}); + return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]); + } else if (method->get_name() == "insert") { + vargs.expectArgs("insert method", {2, 2}, {0, 0}); + auto index = vargs.args[0].get(); + if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method"); + obj.insert(index, vargs.args[1]); + return Value(); + } + } else if (obj.is_object()) { + if (method->get_name() == "items") { + vargs.expectArgs("items method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value::array({key, obj.at(key)})); + } + return result; + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {1, 1}, {0, 0}); + return obj.pop(vargs.args[0]); + } else if (method->get_name() == "get") { + vargs.expectArgs("get method", {1, 2}, {0, 0}); + auto key = vargs.args[0]; + if (vargs.args.size() == 1) { + return obj.contains(key) ? obj.at(key) : Value(); + } else { + return obj.contains(key) ? obj.at(key) : vargs.args[1]; + } + } else if (obj.contains(method->get_name())) { + auto callable = obj.at(method->get_name()); + if (!callable.is_callable()) { + throw std::runtime_error("Property '" + method->get_name() + "' is not callable"); + } + return callable.call(context, vargs); + } + } else if (obj.is_string()) { + auto str = obj.get(); + if (method->get_name() == "strip") { + vargs.expectArgs("strip method", {0, 0}, {0, 0}); + return Value(strip(str)); + } else if (method->get_name() == "endswith") { + vargs.expectArgs("endswith method", {1, 1}, {0, 0}); + auto suffix = vargs.args[0].get(); + return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "title") { + vargs.expectArgs("title method", {0, 0}, {0, 0}); + auto res = str; + for (size_t i = 0, n = res.size(); i < n; ++i) { + if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]); + else res[i] = std::tolower(res[i]); + } + return res; + } + } + throw std::runtime_error("Unknown method: " + method->get_name()); + } +}; + +class CallExpr : public Expression { +public: + std::shared_ptr object; + ArgumentsExpression args; + CallExpr(const Location & location, std::shared_ptr && obj, ArgumentsExpression && a) + : Expression(location), object(std::move(obj)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) throw std::runtime_error("CallExpr.object is null"); + auto obj = object->evaluate(context); + if (!obj.is_callable()) { + throw std::runtime_error("Object is not callable: " + obj.dump(2)); + } + auto vargs = args.evaluate(context); + return obj.call(context, vargs); + } +}; + +class FilterExpr : public Expression { + std::vector> parts; +public: + FilterExpr(const Location & location, std::vector> && p) + : Expression(location), parts(std::move(p)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + Value result; + bool first = true; + for (const auto& part : parts) { + if (!part) throw std::runtime_error("FilterExpr.part is null"); + if (first) { + first = false; + result = part->evaluate(context); + } else { + if (auto ce = dynamic_cast(part.get())) { + auto target = ce->object->evaluate(context); + ArgumentsValue args = ce->args.evaluate(context); + args.args.insert(args.args.begin(), result); + result = target.call(context, args); + } else { + auto callable = part->evaluate(context); + ArgumentsValue args; + args.args.insert(args.args.begin(), result); + result = callable.call(context, args); + } + } + } + return result; + } + + void prepend(std::shared_ptr && e) { + parts.insert(parts.begin(), std::move(e)); + } +}; + +class Parser { +private: + using CharIterator = std::string::const_iterator; + + std::shared_ptr template_str; + CharIterator start, end, it; + Options options; + + Parser(const std::shared_ptr& template_str, const Options & options) : template_str(template_str), options(options) { + if (!template_str) throw std::runtime_error("Template string is null"); + start = it = this->template_str->begin(); + end = this->template_str->end(); + } + + bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) { + if (space_handling == SpaceHandling::Strip) { + while (it != end && std::isspace(*it)) ++it; + } + return true; + } + + std::unique_ptr parseString() { + auto doParse = [&](char quote) -> std::unique_ptr { + if (it == end || *it != quote) return nullptr; + std::string result; + bool escape = false; + for (++it; it != end; ++it) { + if (escape) { + escape = false; + switch (*it) { + case 'n': result += '\n'; break; + case 'r': result += '\r'; break; + case 't': result += '\t'; break; + case 'b': result += '\b'; break; + case 'f': result += '\f'; break; + case '\\': result += '\\'; break; + default: + if (*it == quote) { + result += quote; + } else { + result += *it; + } + break; + } + } else if (*it == '\\') { + escape = true; + } else if (*it == quote) { + ++it; + return std::make_unique(std::move(result)); + } else { + result += *it; + } + } + return nullptr; + }; + + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"') return doParse('"'); + if (*it == '\'') return doParse('\''); + return nullptr; + } + + json parseNumber(CharIterator& it, const CharIterator& end) { + auto before = it; + consumeSpaces(); + auto start = it; + bool hasDecimal = false; + bool hasExponent = false; + + if (it != end && (*it == '-' || *it == '+')) ++it; + + while (it != end) { + if (std::isdigit(*it)) { + ++it; + } else if (*it == '.') { + if (hasDecimal) throw std::runtime_error("Multiple decimal points"); + hasDecimal = true; + ++it; + } else if (it != start && (*it == 'e' || *it == 'E')) { + if (hasExponent) throw std::runtime_error("Multiple exponents"); + hasExponent = true; + ++it; + } else { + break; + } + } + if (start == it) { + it = before; + return json(); // No valid characters found + } + + std::string str(start, it); + try { + return json::parse(str); + } catch (json::parse_error& e) { + throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")"); + return json(); + } + } + + /** integer, float, bool, string */ + std::shared_ptr parseConstant() { + auto start = it; + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"' || *it == '\'') { + auto str = parseString(); + if (str) return std::make_shared(*str); + } + static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)"); + auto token = consumeToken(prim_tok); + if (!token.empty()) { + if (token == "true" || token == "True") return std::make_shared(true); + if (token == "false" || token == "False") return std::make_shared(false); + if (token == "None") return std::make_shared(nullptr); + throw std::runtime_error("Unknown constant token: " + token); + } + + auto number = parseNumber(it, end); + if (!number.is_null()) return std::make_shared(number); + + it = start; + return nullptr; + } + + class expression_parsing_error : public std::runtime_error { + const CharIterator it; + public: + expression_parsing_error(const std::string & message, const CharIterator & it) + : std::runtime_error(message), it(it) {} + size_t get_pos(const CharIterator & begin) const { + return std::distance(begin, it); + } + }; + + bool peekSymbols(const std::vector & symbols) const { + for (const auto & symbol : symbols) { + if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) { + return true; + } + } + return false; + } + + std::vector consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + std::vector ret; + for (size_t i = 0, n = match.size(); i < n; ++i) { + ret.push_back(match[i].str()); + } + return ret; + } + it = start; + return {}; + } + std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + return match[0].str(); + } + it = start; + return ""; + } + + std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) { + it += token.size(); + return token; + } + it = start; + return ""; + } + + std::shared_ptr parseExpression(bool allow_if_expr = true) { + auto left = parseLogicalOr(); + if (it == end) return left; + + if (!allow_if_expr) return left; + + static std::regex if_tok(R"(if\b)"); + if (consumeToken(if_tok).empty()) { + return left; + } + + auto location = get_location(); + auto [condition, else_expr] = parseIfExpression(); + return std::make_shared(location, std::move(condition), std::move(left), std::move(else_expr)); + } + + Location get_location() const { + return {template_str, (size_t) std::distance(start, it)}; + } + + std::pair, std::shared_ptr> parseIfExpression() { + auto condition = parseLogicalOr(); + if (!condition) throw std::runtime_error("Expected condition expression"); + + static std::regex else_tok(R"(else\b)"); + std::shared_ptr else_expr; + if (!consumeToken(else_tok).empty()) { + else_expr = parseExpression(); + if (!else_expr) throw std::runtime_error("Expected 'else' expression"); + } + return std::pair(std::move(condition), std::move(else_expr)); + } + + std::shared_ptr parseLogicalOr() { + auto left = parseLogicalAnd(); + if (!left) throw std::runtime_error("Expected left side of 'logical or' expression"); + + static std::regex or_tok(R"(or\b)"); + auto location = get_location(); + while (!consumeToken(or_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'or' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or); + } + return left; + } + + std::shared_ptr parseLogicalNot() { + static std::regex not_tok(R"(not\b)"); + auto location = get_location(); + + if (!consumeToken(not_tok).empty()) { + auto sub = parseLogicalNot(); + if (!sub) throw std::runtime_error("Expected expression after 'not' keyword"); + return std::make_shared(location, std::move(sub), UnaryOpExpr::Op::LogicalNot); + } + return parseLogicalCompare(); + } + + std::shared_ptr parseLogicalAnd() { + auto left = parseLogicalNot(); + if (!left) throw std::runtime_error("Expected left side of 'logical and' expression"); + + static std::regex and_tok(R"(and\b)"); + auto location = get_location(); + while (!consumeToken(and_tok).empty()) { + auto right = parseLogicalNot(); + if (!right) throw std::runtime_error("Expected right side of 'and' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::And); + } + return left; + } + + std::shared_ptr parseLogicalCompare() { + auto left = parseStringConcat(); + if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); + + static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)"); + static std::regex not_tok(R"(not\b)"); + std::string op_str; + while (!(op_str = consumeToken(compare_tok)).empty()) { + auto location = get_location(); + if (op_str == "is") { + auto negated = !consumeToken(not_tok).empty(); + + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword"); + + return std::make_shared( + left->location, + std::move(left), std::move(identifier), + negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is); + } + auto right = parseStringConcat(); + if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression"); + BinaryOpExpr::Op op; + if (op_str == "==") op = BinaryOpExpr::Op::Eq; + else if (op_str == "!=") op = BinaryOpExpr::Op::Ne; + else if (op_str == "<") op = BinaryOpExpr::Op::Lt; + else if (op_str == ">") op = BinaryOpExpr::Op::Gt; + else if (op_str == "<=") op = BinaryOpExpr::Op::Le; + else if (op_str == ">=") op = BinaryOpExpr::Op::Ge; + else if (op_str == "in") op = BinaryOpExpr::Op::In; + else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn; + else throw std::runtime_error("Unknown comparison operator: " + op_str); + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + Expression::Parameters parseParameters() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list"); + + Expression::Parameters result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.emplace_back(ident->get_name(), std::move(value)); + } else { + result.emplace_back(ident->get_name(), nullptr); + } + } else { + result.emplace_back(std::string(), std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + ArgumentsExpression parseCallArgs() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args"); + + ArgumentsExpression result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.kwargs.emplace_back(ident->get_name(), std::move(value)); + } else { + result.args.emplace_back(std::move(expr)); + } + } else { + result.args.emplace_back(std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + std::shared_ptr parseIdentifier() { + static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)"); + auto location = get_location(); + auto ident = consumeToken(ident_regex); + if (ident.empty()) + return nullptr; + return std::make_shared(location, ident); + } + + std::shared_ptr parseStringConcat() { + auto left = parseMathPow(); + if (!left) throw std::runtime_error("Expected left side of 'string concat' expression"); + + static std::regex concat_tok(R"(~(?!\}))"); + if (!consumeToken(concat_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'string concat' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat); + } + return left; + } + + std::shared_ptr parseMathPow() { + auto left = parseMathPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math pow' expression"); + + while (!consumeToken("**").empty()) { + auto right = parseMathPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math pow' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul); + } + return left; + } + + std::shared_ptr parseMathPlusMinus() { + static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))"); + + auto left = parseMathMulDiv(); + if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression"); + std::string op_str; + while (!(op_str = consumeToken(plus_minus_tok)).empty()) { + auto right = parseMathMulDiv(); + if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression"); + auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + std::shared_ptr parseMathMulDiv() { + auto left = parseMathUnaryPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression"); + + static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))"); + std::string op_str; + while (!(op_str = consumeToken(mul_div_tok)).empty()) { + auto right = parseMathUnaryPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression"); + auto op = op_str == "*" ? BinaryOpExpr::Op::Mul + : op_str == "**" ? BinaryOpExpr::Op::MulMul + : op_str == "/" ? BinaryOpExpr::Op::Div + : op_str == "//" ? BinaryOpExpr::Op::DivDiv + : BinaryOpExpr::Op::Mod; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + + if (!consumeToken("|").empty()) { + auto expr = parseMathMulDiv(); + if (auto filter = dynamic_cast(expr.get())) { + filter->prepend(std::move(left)); + return expr; + } else { + std::vector> parts; + parts.emplace_back(std::move(left)); + parts.emplace_back(std::move(expr)); + return std::make_shared(get_location(), std::move(parts)); + } + } + return left; + } + + std::shared_ptr call_func(const std::string & name, ArgumentsExpression && args) const { + return std::make_shared(get_location(), std::make_shared(get_location(), name), std::move(args)); + } + + std::shared_ptr parseMathUnaryPlusMinus() { + static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))"); + auto op_str = consumeToken(unary_plus_minus_tok); + auto expr = parseExpansion(); + if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression"); + + if (!op_str.empty()) { + auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus; + return std::make_shared(get_location(), std::move(expr), op); + } + return expr; + } + + std::shared_ptr parseExpansion() { + static std::regex expansion_tok(R"(\*\*?)"); + auto op_str = consumeToken(expansion_tok); + auto expr = parseValueExpression(); + if (op_str.empty()) return expr; + if (!expr) throw std::runtime_error("Expected expr of 'expansion' expression"); + return std::make_shared(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict); + } + + std::shared_ptr parseValueExpression() { + auto parseValue = [&]() -> std::shared_ptr { + auto location = get_location(); + auto constant = parseConstant(); + if (constant) return std::make_shared(location, *constant); + + static std::regex null_regex(R"(null\b)"); + if (!consumeToken(null_regex).empty()) return std::make_shared(location, Value()); + + auto identifier = parseIdentifier(); + if (identifier) return identifier; + + auto braced = parseBracedExpressionOrArray(); + if (braced) return braced; + + auto array = parseArray(); + if (array) return array; + + auto dictionary = parseDictionary(); + if (dictionary) return dictionary; + + throw std::runtime_error("Expected value expression"); + }; + + auto value = parseValue(); + + while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { + if (!consumeToken("[").empty()) { + std::shared_ptr index; + if (!consumeToken(":").empty()) { + auto slice_end = parseExpression(); + index = std::make_shared(slice_end->location, nullptr, std::move(slice_end)); + } else { + auto slice_start = parseExpression(); + if (!consumeToken(":").empty()) { + consumeSpaces(); + if (peekSymbols({ "]" })) { + index = std::make_shared(slice_start->location, std::move(slice_start), nullptr); + } else { + auto slice_end = parseExpression(); + index = std::make_shared(slice_start->location, std::move(slice_start), std::move(slice_end)); + } + } else { + index = std::move(slice_start); + } + } + if (!index) throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + + value = std::make_shared(value->location, std::move(value), std::move(index)); + } else if (!consumeToken(".").empty()) { + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier in subscript"); + + consumeSpaces(); + if (peekSymbols({ "(" })) { + auto callParams = parseCallArgs(); + value = std::make_shared(identifier->location, std::move(value), std::move(identifier), std::move(callParams)); + } else { + auto key = std::make_shared(identifier->location, Value(identifier->get_name())); + value = std::make_shared(identifier->location, std::move(value), std::move(key)); + } + } + consumeSpaces(); + } + + if (peekSymbols({ "(" })) { + auto location = get_location(); + auto callParams = parseCallArgs(); + value = std::make_shared(location, std::move(value), std::move(callParams)); + } + return value; + } + + std::shared_ptr parseBracedExpressionOrArray() { + if (consumeToken("(").empty()) return nullptr; + + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in braced expression"); + + if (!consumeToken(")").empty()) { + return expr; // Drop the parentheses + } + + std::vector> tuple; + tuple.emplace_back(std::move(expr)); + + while (it != end) { + if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple"); + auto next = parseExpression(); + if (!next) throw std::runtime_error("Expected expression in tuple"); + tuple.push_back(std::move(next)); + + if (!consumeToken(")").empty()) { + return std::make_shared(get_location(), std::move(tuple)); + } + } + throw std::runtime_error("Expected closing parenthesis"); + } + + std::shared_ptr parseArray() { + if (consumeToken("[").empty()) return nullptr; + + std::vector> elements; + if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + auto first_expr = parseExpression(); + if (!first_expr) throw std::runtime_error("Expected first expression in array"); + elements.push_back(std::move(first_expr)); + + while (it != end) { + if (!consumeToken(",").empty()) { + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in array"); + elements.push_back(std::move(expr)); + } else if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing bracket in array"); + } + } + throw std::runtime_error("Expected closing bracket"); + } + + std::shared_ptr parseDictionary() { + if (consumeToken("{").empty()) return nullptr; + + std::vector, std::shared_ptr>> elements; + if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + + auto parseKeyValuePair = [&]() { + auto key = parseExpression(); + if (!key) throw std::runtime_error("Expected key in dictionary"); + if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary"); + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in dictionary"); + elements.emplace_back(std::pair(std::move(key), std::move(value))); + }; + + parseKeyValuePair(); + + while (it != end) { + if (!consumeToken(",").empty()) { + parseKeyValuePair(); + } else if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing brace in dictionary"); + } + } + throw std::runtime_error("Expected closing brace"); + } + + SpaceHandling parsePreSpace(const std::string& s) const { + if (s == "-") + return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + SpaceHandling parsePostSpace(const std::string& s) const { + if (s == "-") return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + using TemplateTokenVector = std::vector>; + using TemplateTokenIterator = TemplateTokenVector::const_iterator; + + std::vector parseVarNames() { + static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)"); + + std::vector group; + if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names"); + std::vector varnames; + std::istringstream iss(group[1]); + std::string varname; + while (std::getline(iss, varname, ',')) { + varnames.push_back(strip(varname)); + } + return varnames; + } + + std::runtime_error unexpected(const TemplateToken & token) const { + return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + std::runtime_error unterminated(const TemplateToken & token) const { + return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + + TemplateTokenVector tokenize() { + static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); + static std::regex expr_open_regex(R"(\{\{([-~])?)"); + static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)"); + static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); + static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); + static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); + + TemplateTokenVector tokens; + std::vector group; + std::string text; + std::smatch match; + + try { + while (it != end) { + auto location = get_location(); + + if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto content = group[2]; + auto post_space = parsePostSpace(group[3]); + tokens.push_back(std::make_unique(location, pre_space, post_space, content)); + } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto expr = parseExpression(); + + if ((group = consumeTokenGroups(expr_close_regex)).empty()) { + throw std::runtime_error("Expected closing expression tag"); + } + + auto post_space = parsePostSpace(group[1]); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(expr))); + } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + + std::string keyword; + + auto parseBlockClose = [&]() -> SpaceHandling { + if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag"); + return parsePostSpace(group[1]); + }; + + if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword"); + + if (keyword == "if") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in if block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "elif") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in elif block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "else") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "endif") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "for") { + static std::regex recursive_tok(R"(recursive\b)"); + static std::regex if_tok(R"(if\b)"); + + auto varnames = parseVarNames(); + static std::regex in_tok(R"(in\b)"); + if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block"); + auto iterable = parseExpression(/* allow_if_expr = */ false); + if (!iterable) throw std::runtime_error("Expected iterable in for block"); + + std::shared_ptr condition; + if (!consumeToken(if_tok).empty()) { + condition = parseExpression(); + } + auto recursive = !consumeToken(recursive_tok).empty(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive)); + } else if (keyword == "endfor") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "generation") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "endgeneration") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "set") { + static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))"); + + std::string ns; + std::vector var_names; + std::shared_ptr value; + if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { + ns = group[1]; + var_names.push_back(group[2]); + + if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block"); + + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } else { + var_names = parseVarNames(); + + if (!consumeToken("=").empty()) { + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } + } + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, ns, var_names, std::move(value))); + } else if (keyword == "endset") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "macro") { + auto macroname = parseIdentifier(); + if (!macroname) throw std::runtime_error("Expected macro name in macro block"); + auto params = parseParameters(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(macroname), std::move(params))); + } else if (keyword == "endmacro") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "filter") { + auto filter = parseExpression(); + if (!filter) throw std::runtime_error("Expected expression in filter block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(filter))); + } else if (keyword == "endfilter") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else { + throw std::runtime_error("Unexpected block: " + keyword); + } + } else if (std::regex_search(it, end, match, non_text_open_regex)) { + auto text_end = it + match.position(); + text = std::string(it, text_end); + it = text_end; + tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } else { + text = std::string(it, end); + it = end; + tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } + } + return tokens; + } catch (const std::exception & e) { + throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it))); + } + } + + std::shared_ptr parseTemplate( + const TemplateTokenIterator & begin, + TemplateTokenIterator & it, + const TemplateTokenIterator & end, + bool fully = false) const { + std::vector> children; + while (it != end) { + const auto start = it; + const auto & token = *(it++); + if (auto if_token = dynamic_cast(token.get())) { + std::vector, std::shared_ptr>> cascade; + cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end)); + + while (it != end && (*it)->type == TemplateToken::Type::Elif) { + auto elif_token = dynamic_cast((*(it++)).get()); + cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end)); + } + + if (it != end && (*it)->type == TemplateToken::Type::Else) { + cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end)); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(cascade))); + } else if (auto for_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + auto else_body = std::shared_ptr(); + if (it != end && (*it)->type == TemplateToken::Type::Else) { + else_body = parseTemplate(begin, ++it, end); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body))); + } else if (dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) { + throw unterminated(**start); + } + // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking). + children.emplace_back(std::move(body)); + } else if (auto text_token = dynamic_cast(token.get())) { + SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; + SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep; + + auto text = text_token->text; + if (post_space == SpaceHandling::Strip) { + static std::regex trailing_space_regex(R"((\s|\r|\n)+$)"); + text = std::regex_replace(text, trailing_space_regex, ""); + } else if (options.lstrip_blocks && it != end) { + auto i = text.size(); + while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--; + if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) { + text.resize(i); + } + } + if (pre_space == SpaceHandling::Strip) { + static std::regex leading_space_regex(R"(^(\s|\r|\n)+)"); + text = std::regex_replace(text, leading_space_regex, ""); + } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { + if (text.length() > 0 && text[0] == '\n') { + text.erase(0, 1); + } + } + if (it == end && !options.keep_trailing_newline) { + auto i = text.size(); + if (i > 0 && text[i - 1] == '\n') { + i--; + if (i > 0 && text[i - 1] == '\r') i--; + text.resize(i); + } + } + children.emplace_back(std::make_shared(token->location, text)); + } else if (auto expr_token = dynamic_cast(token.get())) { + children.emplace_back(std::make_shared(token->location, std::move(expr_token->expr))); + } else if (auto set_token = dynamic_cast(token.get())) { + if (set_token->value) { + children.emplace_back(std::make_shared(token->location, set_token->ns, set_token->var_names, std::move(set_token->value))); + } else { + auto value_template = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) { + throw unterminated(**start); + } + if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value"); + if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value"); + auto & name = set_token->var_names[0]; + children.emplace_back(std::make_shared(token->location, name, std::move(value_template))); + } + } else if (auto macro_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if (auto filter_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(filter_token->filter), std::move(body))); + } else if (dynamic_cast(token.get())) { + // Ignore comments + } else if (dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get())) { + it--; // unconsume the token + break; // exit the loop + } else { + throw unexpected(**(it-1)); + } + } + if (fully && it != end) { + throw unexpected(**it); + } + if (children.empty()) { + return std::make_shared(Location { template_str, 0 }, std::string()); + } else if (children.size() == 1) { + return std::move(children[0]); + } else { + return std::make_shared(children[0]->location(), std::move(children)); + } + } + +public: + + static std::shared_ptr parse(const std::string& template_str, const Options & options) { + Parser parser(std::make_shared(normalize_newlines(template_str)), options); + auto tokens = parser.tokenize(); + TemplateTokenIterator begin = tokens.begin(); + auto it = begin; + TemplateTokenIterator end = tokens.end(); + return parser.parseTemplate(begin, it, end, /* full= */ true); + } +}; + +static Value simple_function(const std::string & fn_name, const std::vector & params, const std::function &, Value & args)> & fn) { + std::map named_positions; + for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i; + + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) -> Value { + auto args_obj = Value::object(); + std::vector provided_args(params.size()); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i < params.size()) { + args_obj.set(params[i], arg); + provided_args[i] = true; + } else { + throw std::runtime_error("Too many positional params for " + fn_name); + } + } + for (auto & [name, value] : args.kwargs) { + auto named_pos_it = named_positions.find(name); + if (named_pos_it == named_positions.end()) { + throw std::runtime_error("Unknown argument " + name + " for function " + fn_name); + } + provided_args[named_pos_it->second] = true; + args_obj.set(name, value); + } + return fn(context, args_obj); + }); +} + +inline std::shared_ptr Context::builtins() { + auto globals = Value::object(); + + globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr &, Value & args) -> Value { + throw std::runtime_error(args.at("message").get()); + })); + globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { + return Value(args.at("value").dump(args.get("indent", -1), /* tojson= */ true)); + })); + globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { + auto items = Value::array(); + if (args.contains("object")) { + auto & obj = args.at("object"); + if (obj.is_string()) { + auto json_obj = json::parse(obj.get()); + for (const auto & kv : json_obj.items()) { + items.push_back(Value::array({kv.key(), kv.value()})); + } + } else if (!obj.is_null()) { + for (auto & key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); + } + } + } + return items; + })); + globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { + auto items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not a list"); + if (items.size() == 0) return Value(); + return items.at(items.size() - 1); + })); + globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { + auto & text = args.at("text"); + return text.is_null() ? text : Value(strip(text.get())); + })); + globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr &, Value & args) { + auto text = args.at("text"); + if (text.is_null()) return text; + std::string res; + auto str = text.get(); + std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower); + return Value(res); + })); + globals.set("default", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + args.expectArgs("default", {2, 3}, {0, 1}); + auto & value = args.args[0]; + auto & default_value = args.args[1]; + bool boolean = false; + if (args.args.size() == 3) { + boolean = args.args[2].get(); + } else { + Value bv = args.get_named("boolean"); + if (!bv.is_null()) { + boolean = bv.get(); + } + } + return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value; + })); + auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr &, Value & args) { + return Value(html_escape(args.at("text").get())); + }); + globals.set("e", escape); + globals.set("escape", escape); + globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr &, Value & args) { + auto sep = args.get("sep", ""); + auto first = std::make_shared(true); + return simple_function("", {}, [sep, first](const std::shared_ptr &, const Value &) -> Value { + if (*first) { + *first = false; + return ""; + } + return sep; + }); + return Value(html_escape(args.at("text").get())); + })); + globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr &, Value & args) { + return Value((int64_t) args.at("items").size()); + })); + globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr &, Value & args) { + if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)"); + auto & value = args.at("value"); + auto keys = value.keys(); + std::sort(keys.begin(), keys.end()); + auto res = Value::array(); + for (auto & key : keys) { + res.push_back(Value::array({key, value.at(key)})); + } + return res; + })); + globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { + auto do_join = [](Value & items, const std::string & sep) { + std::ostringstream oss; + auto first = true; + for (size_t i = 0, n = items.size(); i < n; ++i) { + if (first) first = false; + else oss << sep; + oss << items.at(i).to_str(); + } + return Value(oss.str()); + }; + auto sep = args.get("d", ""); + if (args.contains("items")) { + auto & items = args.at("items"); + return do_join(items, sep); + } else { + return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr &, Value & args) { + auto & items = args.at("items"); + if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump()); + return do_join(items, sep); + }); + } + })); + globals.set("namespace", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + auto ns = Value::object(); + args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits::max)()}); + for (auto & [name, value] : args.kwargs) { + ns.set(name, value); + } + return ns; + })); + auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("actual") == args.at("expected"); + }); + globals.set("equalto", equalto); + globals.set("==", equalto); + globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + return (int64_t) items.size(); + })); + globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_int(); + })); + globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not iterable"); + return items; + })); + globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not iterable"); + std::unordered_set seen; + auto result = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto pair = seen.insert(items.at(i)); + if (pair.second) { + result.push_back(items.at(i)); + } + } + return result; + })); + auto make_filter = [](const Value & filter, Value & extra_args) -> Value { + return simple_function("", { "value" }, [=](const std::shared_ptr & context, Value & args) { + auto & value = args.at("value"); + ArgumentsValue actual_args; + actual_args.args.emplace_back(value); + for (size_t i = 0, n = extra_args.size(); i < n; i++) { + actual_args.args.emplace_back(extra_args.at(i)); + } + return filter.call(context, actual_args); + }); + }; + // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject + globals.set("reject", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs("reject", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + auto filter_fn = context->get(args.args[1]); + if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + + auto filter_args = Value::array(); + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.push_back(args.args[i]); + } + auto filter = make_filter(filter_fn, filter_args); + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + ArgumentsValue filter_args; + filter_args.args.emplace_back(item); + auto pred_res = filter.call(context, filter_args); + if (!pred_res.to_bool()) { + res.push_back(item); + } + } + return res; + })); + globals.set("map", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + auto res = Value::array(); + if (args.args.size() == 1 && + ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) { + auto & items = args.args[0]; + auto attr_name = args.get_named("attribute"); + auto default_value = args.get_named("default"); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + res.push_back(attr.is_null() ? default_value : attr); + } + } else if (args.kwargs.empty() && args.args.size() >= 2) { + auto fn = context->get(args.args[1]); + if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + ArgumentsValue filter_args { {Value()}, {} }; + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.args.emplace_back(args.args[i]); + } + for (size_t i = 0, n = args.args[0].size(); i < n; i++) { + auto & item = args.args[0].at(i); + filter_args.args[0] = item; + res.push_back(fn.call(context, filter_args)); + } + } else { + throw std::runtime_error("Invalid or unsupported arguments for map"); + } + return res; + })); + globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr &, Value & args) { + auto text = args.at("text").get(); + auto first = args.get("first", false); + std::string out; + std::string indent(args.get("indent", 0), ' '); + std::istringstream iss(text); + std::string line; + auto is_first = true; + while (std::getline(iss, line, '\n')) { + auto needs_indent = !is_first || first; + if (is_first) is_first = false; + else out += "\n"; + if (needs_indent) out += indent; + out += line; + } + if (!text.empty() && text.back() == '\n') out += "\n"; + return out; + })); + globals.set("selectattr", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs("selectattr", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + auto attr_name = args.args[1].get(); + + bool has_test = false; + Value test_fn; + ArgumentsValue test_args {{Value()}, {}}; + if (args.args.size() >= 3) { + has_test = true; + test_fn = context->get(args.args[2]); + if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump()); + for (size_t i = 3, n = args.args.size(); i < n; i++) { + test_args.args.emplace_back(args.args[i]); + } + test_args.kwargs = args.kwargs; + } + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + if (has_test) { + test_args.args[0] = attr; + if (test_fn.call(context, test_args).to_bool()) { + res.push_back(item); + } + } else { + res.push_back(attr); + } + } + return res; + })); + globals.set("range", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + std::vector startEndStep(3); + std::vector param_set(3); + if (args.args.size() == 1) { + startEndStep[1] = args.args[0].get(); + param_set[1] = true; + } else { + for (size_t i = 0; i < args.args.size(); i++) { + auto & arg = args.args[i]; + auto v = arg.get(); + startEndStep[i] = v; + param_set[i] = true; + } + } + for (auto & [name, value] : args.kwargs) { + size_t i; + if (name == "start") i = 0; + else if (name == "end") i = 1; + else if (name == "step") i = 2; + else throw std::runtime_error("Unknown argument " + name + " for function range"); + + if (param_set[i]) { + throw std::runtime_error("Duplicate argument " + name + " for function range"); + } + startEndStep[i] = value.get(); + param_set[i] = true; + } + if (!param_set[1]) { + throw std::runtime_error("Missing required argument 'end' for function range"); + } + int64_t start = param_set[0] ? startEndStep[0] : 0; + int64_t end = startEndStep[1]; + int64_t step = param_set[2] ? startEndStep[2] : 1; + + auto res = Value::array(); + if (step > 0) { + for (int64_t i = start; i < end; i += step) { + res.push_back(Value(i)); + } + } else { + for (int64_t i = start; i > end; i += step) { + res.push_back(Value(i)); + } + } + return res; + })); + + return std::make_shared(std::move(globals)); +} + +inline std::shared_ptr Context::make(Value && values, const std::shared_ptr & parent) { + return std::make_shared(values.is_null() ? Value::object() : std::move(values), parent); +} + +} // namespace minja diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 95f112043..63b54a9cf 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -696,6 +696,9 @@ class Model: if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5": # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3 res = "deepseek-v3" + if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B + res = "deepseek-r1-qwen" if res is None: logger.warning("\n") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 56edc64a7..cea34413f 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -65,49 +65,50 @@ else: # TODO: add models here, base models preferred models = [ - {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", }, - {"name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", }, - {"name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", }, - {"name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", }, - {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", }, - {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", }, - {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", }, - {"name": "falcon3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon3-7B-Base", }, - {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", }, - {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", }, - {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", }, - {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, - {"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", }, - {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, - {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, - {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, - {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, - {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, - {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", }, - {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! - {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, - {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, - {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", }, - {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", }, - {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", }, - {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B - {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", }, - {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, - {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, - {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, - {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, - {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, - {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, - {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", }, - {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, - {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, - {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, - {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, - {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", }, - {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, - {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, - {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, - {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, + {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", }, + {"name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", }, + {"name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", }, + {"name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", }, + {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", }, + {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", }, + {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", }, + {"name": "falcon3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon3-7B-Base", }, + {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", }, + {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", }, + {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", }, + {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, + {"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", }, + {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, + {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, + {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, + {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, + {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, + {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", }, + {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! + {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, + {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, + {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", }, + {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", }, + {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", }, + {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B + {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", }, + {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, + {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, + {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, + {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, + {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, + {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, + {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", }, + {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, + {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, + {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, + {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, + {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", }, + {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, + {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, + {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, + {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, + {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"}, ] diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 8d8312e91..89ddbd669 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -133,7 +133,7 @@ The docker build option is currently limited to *intel GPU* targets. ### Build image ```sh # Using FP16 -docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" -f .devops/llama-cli-intel.Dockerfile . +docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile . ``` *Notes*: diff --git a/docs/build.md b/docs/build.md index 3b0d2211d..dd6495028 100644 --- a/docs/build.md +++ b/docs/build.md @@ -286,7 +286,7 @@ You don't need to install Vulkan SDK. It will be installed inside the container. ```sh # Build the image -docker build -t llama-cpp-vulkan -f .devops/llama-cli-vulkan.Dockerfile . +docker build -t llama-cpp-vulkan --target light -f .devops/vulkan.Dockerfile . # Then, use it: docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-vulkan -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 diff --git a/docs/docker.md b/docs/docker.md index 8d90e6ded..dac9a9ec1 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -60,9 +60,9 @@ Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia ## Building Docker locally ```bash -docker build -t local/llama.cpp:full-cuda -f .devops/full-cuda.Dockerfile . -docker build -t local/llama.cpp:light-cuda -f .devops/llama-cli-cuda.Dockerfile . -docker build -t local/llama.cpp:server-cuda -f .devops/llama-server-cuda.Dockerfile . +docker build -t local/llama.cpp:full-cuda --target full -f .devops/cuda.Dockerfile . +docker build -t local/llama.cpp:light-cuda --target light -f .devops/cuda.Dockerfile . +docker build -t local/llama.cpp:server-cuda --target server -f .devops/cuda.Dockerfile . ``` You may want to pass in some different `ARGS`, depending on the CUDA environment supported by your container host, as well as the GPU architecture. @@ -95,9 +95,9 @@ Assuming one has the [mt-container-toolkit](https://developer.mthreads.com/musa/ ## Building Docker locally ```bash -docker build -t local/llama.cpp:full-musa -f .devops/full-musa.Dockerfile . -docker build -t local/llama.cpp:light-musa -f .devops/llama-cli-musa.Dockerfile . -docker build -t local/llama.cpp:server-musa -f .devops/llama-server-musa.Dockerfile . +docker build -t local/llama.cpp:full-musa --target full -f .devops/musa.Dockerfile . +docker build -t local/llama.cpp:light-musa --target light -f .devops/musa.Dockerfile . +docker build -t local/llama.cpp:server-musa --target server -f .devops/musa.Dockerfile . ``` You may want to pass in some different `ARGS`, depending on the MUSA environment supported by your container host, as well as the GPU architecture. diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index 99063b5d5..91238e4be 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -345,8 +345,18 @@ struct lora_merge_ctx { gf = ggml_new_graph(ctx0); struct ggml_tensor * cur = inp_base; for (size_t i = 0; i < adapters.size(); ++i) { - struct ggml_tensor * a_T = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))); - struct ggml_tensor * delta = ggml_mul_mat(ctx0, a_T, ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32)); + struct ggml_tensor * delta; + bool is_tok_embd = string_starts_with(name_base, "token_embd"); + if (is_tok_embd) { + printf("%s : detected token embeddings tensor\n", __func__); + delta = ggml_mul_mat(ctx0, + ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32), + ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32)); + } else { + delta = ggml_mul_mat(ctx0, + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))), + ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32)); + } // scale const float alpha = adapters[i]->alpha; const float rank = (float) inp_b[i]->ne[0]; diff --git a/examples/llava/README-minicpmo2.6.md b/examples/llava/README-minicpmo2.6.md new file mode 100644 index 000000000..8713a43d6 --- /dev/null +++ b/examples/llava/README-minicpmo2.6.md @@ -0,0 +1,46 @@ +## MiniCPM-o 2.6 +Currently, this readme only supports minicpm-omni's image capabilities, and we will update the full-mode support as soon as possible. + +### Prepare models and code + +Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from huggingface to "MiniCPM-o-2_6" folder. + +Clone llama.cpp: +```bash +git clone git@github.com:OpenBMB/llama.cpp.git +cd llama.cpp +git checkout minicpm-omni +``` + +### Usage of MiniCPM-o 2.6 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us) + +```bash +python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6 +python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4 +python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model + +# quantize int4 version +./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + +Build llama.cpp using `CMake`: +https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md + +```bash +cmake -B build +cmake --build build --config Release +``` + +Inference on Linux or Mac +``` +# run f16 version +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run quantized int4 version +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# or run in interactive mode +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i +``` diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 7a8a3156b..24073c5a9 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 else if (ctx->minicpmv_version == 3) { pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); } + else if (ctx->minicpmv_version == 4) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); + } ggml_set_name(pos_embed, "pos_embed"); ggml_set_input(pos_embed); } @@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 n_head = hidden_size/d_head; num_query = 64; } + else if (ctx->minicpmv_version == 4) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); @@ -2041,6 +2049,7 @@ static std::vector> uhd_slice_image(const clip_imag images[images.size()-1].push_back(patch); } } + clip_image_u8_free(refine_image); } return images; } @@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli clip_image_f32_free(res); } } + for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t j = 0; j < imgs[i].size(); ++j) { + if (imgs[i][j] != nullptr) { + clip_image_u8_free(imgs[i][j]); + } + } + } return true; } else if (ctx->has_qwen2vl_merger) { @@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i else if (ctx->minicpmv_version == 3) { n_patches = 64; } + else if (ctx->minicpmv_version == 4) { + n_patches = 64; + } } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); @@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); - int bucket_coords_h[70]; - int bucket_coords_w[70]; + int bucket_coords_h[1024]; + int bucket_coords_w[1024]; for (int i = 0; i < pos_h; i++){ bucket_coords_h[i] = std::floor(70.0*i/pos_h); } @@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima else if (ctx->minicpmv_version == 3) { embed_dim = 3584; } + else if (ctx->minicpmv_version == 4) { + embed_dim = 3584; + } auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); @@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { else if (ctx->minicpmv_version == 3) { return 3584; } + else if (ctx->minicpmv_version == 4) { + return 3584; + } } if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { return ctx->vision_model.mm_1_b->ne[0]; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index c598caf3d..2cac7933d 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector return true; } -static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) { +static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { int width = image->nx; int height = image->ny; int num_patches = (height / patch_size) * (width / patch_size); @@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); } else { - int has_minicpmv_projector = clip_is_minicpmv(ctx_clip); - if (has_minicpmv_projector == 2) { - encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); - } - else if (has_minicpmv_projector == 3) { - encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); - } + encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); } if (!encoded) { @@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli load_image_size->height = img->ny; clip_add_load_image_size(ctx_clip, load_image_size); LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); + delete[] img_res_v.data; + img_res_v.size = 0; + img_res_v.data = nullptr; } else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 38c44e130..53d902d61 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e else if (has_minicpmv_projector == 3) { system_prompt = "<|im_start|>user\n"; } + else if (has_minicpmv_projector == 4) { + system_prompt = "<|im_start|>user\n"; + } LOG_INF("%s: image token past: %d\n", __func__, n_past); eval_string(ctx_llava->ctx_llama, (system_prompt+"").c_str(), params->n_batch, &n_past, false); process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); @@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm else if (has_minicpmv_projector == 3) { user_prompt = "<|im_start|>user\n" + prompt; } + else if (has_minicpmv_projector == 4) { + user_prompt = "<|im_start|>user\n" + prompt; + } } eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); @@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm else if (has_minicpmv_projector == 3) { eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); } + else if (has_minicpmv_projector == 4) { + eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); + } // generate the response @@ -308,7 +317,6 @@ int main(int argc, char ** argv) { const auto * tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0) break; - if (strstr(tmp, "###")) break; // Yi-VL behavior printf("%s", tmp);// mistral llava-1.6 if (strstr(response.c_str(), "")) break; // minicpm-v fflush(stdout); diff --git a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py index ea773742a..9b196757f 100644 --- a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py +++ b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py @@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073] default_image_std = [0.26862954, 0.26130258, 0.27577711] ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) -ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2) +ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2) # with proper args = ap.parse_args() @@ -545,12 +545,19 @@ if args.use_f32: minicpmv_version = args.minicpmv_version emb_dim = 4096 +block_count = 26 if minicpmv_version == 1: emb_dim = 2304 + block_count = 26 elif minicpmv_version == 2: emb_dim = 4096 + block_count = 27 elif minicpmv_version == 3: emb_dim = 3584 + block_count = 27 +elif minicpmv_version == 4: + emb_dim = 3584 + block_count = 27 default_vision_config = { "hidden_size": 1152, @@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config) if minicpmv_version == 3: vision_config = SiglipVisionConfig(**default_vision_config) model = SiglipVisionTransformer(vision_config) +elif minicpmv_version == 4: + vision_config = SiglipVisionConfig(**default_vision_config) + model = SiglipVisionTransformer(vision_config) processor = None # if model.attn_pool is not None: @@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None: fname_middle = "mmproj-" has_text_encoder = False has_minicpmv_projector = True - minicpmv_version = 3 + minicpmv_version = 4 elif args.vision_only: fname_middle = "vision-" has_text_encoder = False @@ -625,7 +635,6 @@ if has_vision_encoder: fout.add_uint32("clip.vision.projection_dim", 0) fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16) fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) - block_count = 26 fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count) if processor is not None: diff --git a/examples/llava/minicpmv-surgery.py b/examples/llava/minicpmv-surgery.py index 748ff5c57..ba8211658 100644 --- a/examples/llava/minicpmv-surgery.py +++ b/examples/llava/minicpmv-surgery.py @@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model") args = ap.parse_args() # find the model part that includes the the multimodal projector weights -model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) +model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16) checkpoint = model.state_dict() # get a list of mm tensor names diff --git a/examples/main/README.md b/examples/main/README.md index 17d80a622..46f92eb7a 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -310,9 +310,9 @@ These options help improve the performance and memory usage of the LLaMA models. ### Batch Size -- `-b N, --batch-size N`: Set the batch size for prompt processing (default: `2048`). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations. +- `-ub N`, `--ubatch-size N`: Physical batch size. This is the maximum number of tokens that may be processed at a time. Increasing this value may improve performance during prompt processing, at the expense of higher memory usage. Default: `512`. -- `-ub N`, `--ubatch-size N`: physical maximum batch size. This is for pipeline parallelization. Default: `512`. +- `-b N`, `--batch-size N`: Logical batch size. Increasing this value above the value of the physical batch size may improve prompt processing performance when using multiple GPUs with pipeline parallelism. Default: `2048`. ### Prompt Caching diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 39666a0e8..da2a03ab9 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -4,6 +4,7 @@ #include "log.h" #include "sampling.h" #include "llama.h" +#include "chat-template.hpp" #include #include @@ -84,14 +85,6 @@ static void sigint_handler(int signo) { } #endif -static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, const std::string & role, const std::string & content) { - common_chat_msg new_msg{role, content}; - auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); - chat_msgs.push_back({role, content}); - LOG_DBG("formatted: '%s'\n", formatted.c_str()); - return formatted; -} - int main(int argc, char ** argv) { common_params params; g_params = ¶ms; @@ -165,6 +158,7 @@ int main(int argc, char ** argv) { } const llama_vocab * vocab = llama_model_get_vocab(model); + auto chat_templates = common_chat_templates_from_model(model, params.chat_template); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); @@ -207,7 +201,7 @@ int main(int argc, char ** argv) { } // auto enable conversation mode if chat template is available - const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty(); + const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default; if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { if (has_chat_template) { LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); @@ -225,7 +219,7 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } @@ -269,10 +263,18 @@ int main(int argc, char ** argv) { std::vector embd_inp; + auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { + common_chat_msg new_msg{role, content}; + auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); + chat_msgs.push_back({role, content}); + LOG_DBG("formatted: '%s'\n", formatted.c_str()); + return formatted; + }; + { auto prompt = (params.conversation_mode && params.enable_chat_template) // format the system prompt in conversation mode (fallback to default if empty) - ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) + ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) // otherwise use the prompt as is : params.prompt; if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { @@ -779,7 +781,7 @@ int main(int argc, char ** argv) { } if (params.enable_chat_template) { - chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str()); + chat_add_and_format("assistant", assistant_ss.str()); } is_interacting = true; LOG("\n"); @@ -844,7 +846,7 @@ int main(int argc, char ** argv) { bool format_chat = params.conversation_mode && params.enable_chat_template; std::string user_inp = format_chat - ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer)) + ? chat_add_and_format("user", std::move(buffer)) : std::move(buffer); // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true); diff --git a/examples/run/README.md b/examples/run/README.md index a06805441..89a552079 100644 --- a/examples/run/README.md +++ b/examples/run/README.md @@ -3,11 +3,10 @@ The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models. ```bash -llama-run granite-code +llama-run granite3-moe ``` ```bash -llama-run -h Description: Runs a llm @@ -17,7 +16,7 @@ Usage: Options: -c, --context-size Context size (default: 2048) - -n, --ngl + -n, -ngl, --ngl Number of GPU layers (default: 0) --temp Temperature (default: 0.8) diff --git a/examples/run/linenoise.cpp/linenoise.cpp b/examples/run/linenoise.cpp/linenoise.cpp index 050c23012..a68f12a1a 100644 --- a/examples/run/linenoise.cpp/linenoise.cpp +++ b/examples/run/linenoise.cpp/linenoise.cpp @@ -103,24 +103,26 @@ * */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "linenoise.h" +# include "linenoise.h" -#define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 -#define LINENOISE_MAX_LINE 4096 -static std::vector unsupported_term = {"dumb","cons25","emacs",nullptr}; +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include +# include +# include + +# define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 +# define LINENOISE_MAX_LINE 4096 +static std::vector unsupported_term = { "dumb", "cons25", "emacs" }; static linenoiseCompletionCallback *completionCallback = NULL; static linenoiseHintsCallback *hintsCallback = NULL; static linenoiseFreeHintsCallback *freeHintsCallback = NULL; @@ -166,21 +168,58 @@ int linenoiseHistoryAdd(const char *line); #define REFRESH_ALL (REFRESH_CLEAN|REFRESH_WRITE) // Do both. static void refreshLine(struct linenoiseState *l); +class File { + public: + FILE * file = nullptr; + + FILE * open(const std::string & filename, const char * mode) { + file = fopen(filename.c_str(), mode); + + return file; + } + + int lock() { + if (file) { + fd = fileno(file); + if (flock(fd, LOCK_EX | LOCK_NB) != 0) { + fd = -1; + + return 1; + } + } + + return 0; + } + + ~File() { + if (fd >= 0) { + flock(fd, LOCK_UN); + } + + if (file) { + fclose(file); + } + } + + private: + int fd = -1; +}; + __attribute__((format(printf, 1, 2))) /* Debugging function. */ #if 0 static void lndebug(const char *fmt, ...) { - static FILE *lndebug_fp = NULL; - if (lndebug_fp == NULL) { - lndebug_fp = fopen("/tmp/lndebug.txt", "a"); + static File file; + if (file.file == nullptr) { + file.open("/tmp/lndebug.txt", "a"); } - if (lndebug_fp != NULL) { + if (file.file != nullptr) { va_list args; va_start(args, fmt); - vfprintf(lndebug_fp, fmt, args); + vfprintf(file.file, fmt, args); va_end(args); - fflush(lndebug_fp); + fflush(file.file); } } #else @@ -213,8 +252,11 @@ void linenoiseSetMultiLine(int ml) { static int isUnsupportedTerm(void) { char *term = getenv("TERM"); if (term == NULL) return 0; - for (int j = 0; unsupported_term[j]; ++j) - if (!strcasecmp(term, unsupported_term[j])) return 1; + for (size_t j = 0; j < unsupported_term.size(); ++j) { + if (!strcasecmp(term, unsupported_term[j])) { + return 1; + } + } return 0; } @@ -334,17 +376,6 @@ static void linenoiseBeep(void) { fflush(stderr); } -/* ============================== Completion ================================ */ - -/* Free a list of completion option populated by linenoiseAddCompletion(). */ -static void freeCompletions(linenoiseCompletions *lc) { - size_t i; - for (i = 0; i < lc->len; i++) - free(lc->cvec[i]); - if (lc->cvec != NULL) - free(lc->cvec); -} - /* Called by completeLine() and linenoiseShow() to render the current * edited line with the proposed completion. If the current completion table * is already available, it is passed as second argument, otherwise the @@ -353,9 +384,9 @@ static void freeCompletions(linenoiseCompletions *lc) { * Flags are the same as refreshLine*(), that is REFRESH_* macros. */ static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags) { /* Obtain the table of completions if the caller didn't provide one. */ - linenoiseCompletions ctable = { 0, NULL }; + linenoiseCompletions ctable; if (lc == NULL) { - completionCallback(ls->buf,&ctable); + completionCallback(ls->buf, &ctable); lc = &ctable; } @@ -364,16 +395,17 @@ static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseComple struct linenoiseState saved = *ls; ls->len = ls->pos = strlen(lc->cvec[ls->completion_idx]); ls->buf = lc->cvec[ls->completion_idx]; - refreshLineWithFlags(ls,flags); + refreshLineWithFlags(ls, flags); ls->len = saved.len; ls->pos = saved.pos; ls->buf = saved.buf; } else { - refreshLineWithFlags(ls,flags); + refreshLineWithFlags(ls, flags); } - /* Free the completions table if needed. */ - if (lc != &ctable) freeCompletions(&ctable); + if (lc == &ctable) { + ctable.to_free = false; + } } /* This is an helper function for linenoiseEdit*() and is called when the @@ -391,11 +423,11 @@ static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseComple * possible completions, and the caller should read for the next characters * from stdin. */ static int completeLine(struct linenoiseState *ls, int keypressed) { - linenoiseCompletions lc = { 0, NULL }; + linenoiseCompletions lc; int nwritten; char c = keypressed; - completionCallback(ls->buf,&lc); + completionCallback(ls->buf, &lc); if (lc.len == 0) { linenoiseBeep(); ls->in_completion = 0; @@ -406,7 +438,7 @@ static int completeLine(struct linenoiseState *ls, int keypressed) { ls->in_completion = 1; ls->completion_idx = 0; } else { - ls->completion_idx = (ls->completion_idx+1) % (lc.len+1); + ls->completion_idx = (ls->completion_idx + 1) % (lc.len + 1); if (ls->completion_idx == lc.len) linenoiseBeep(); } c = 0; @@ -420,8 +452,7 @@ static int completeLine(struct linenoiseState *ls, int keypressed) { default: /* Update buffer and return */ if (ls->completion_idx < lc.len) { - nwritten = snprintf(ls->buf,ls->buflen,"%s", - lc.cvec[ls->completion_idx]); + nwritten = snprintf(ls->buf, ls->buflen, "%s", lc.cvec[ls->completion_idx]); ls->len = ls->pos = nwritten; } ls->in_completion = 0; @@ -430,13 +461,12 @@ static int completeLine(struct linenoiseState *ls, int keypressed) { /* Show completion or original buffer */ if (ls->in_completion && ls->completion_idx < lc.len) { - refreshLineWithCompletion(ls,&lc,REFRESH_ALL); + refreshLineWithCompletion(ls, &lc, REFRESH_ALL); } else { refreshLine(ls); } } - freeCompletions(&lc); return c; /* Return last read character */ } @@ -462,53 +492,25 @@ void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *fn) { * user typed . See the example.c source code for a very easy to * understand example. */ void linenoiseAddCompletion(linenoiseCompletions *lc, const char *str) { - size_t len = strlen(str); - char *copy, **cvec; - - copy = (char*) malloc(len + 1); - if (copy == NULL) return; - memcpy(copy,str,len+1); - cvec = (char**) realloc(lc->cvec,sizeof(char*)*(lc->len+1)); - if (cvec == NULL) { - free(copy); + const size_t len = strlen(str); + auto copy = std::make_unique(len + 1); + if (!copy) { return; } + + memcpy(copy.get(), str, len + 1); + char ** cvec = static_cast(std::realloc(lc->cvec, sizeof(char *) * (lc->len + 1))); + if (cvec == nullptr) { + return; + } + lc->cvec = cvec; - lc->cvec[lc->len++] = copy; -} - -/* =========================== Line editing ================================= */ - -/* We define a very simple "append buffer" structure, that is an heap - * allocated string where we can append to. This is useful in order to - * write all the escape sequences in a buffer and flush them to the standard - * output in a single call, to avoid flickering effects. */ -struct abuf { - char *b; - int len; -}; - -static void abInit(struct abuf *ab) { - ab->b = NULL; - ab->len = 0; -} - -static void abAppend(struct abuf *ab, const char *s, int len) { - char *new_ptr = (char*) realloc(ab->b,ab->len+len); - - if (new_ptr == NULL) return; - memcpy(new_ptr+ab->len,s,len); - ab->b = new_ptr; - ab->len += len; -} - -static void abFree(struct abuf *ab) { - free(ab->b); + lc->cvec[lc->len++] = copy.release(); } /* Helper of refreshSingleLine() and refreshMultiLine() to show hints * to the right of the prompt. */ -static void refreshShowHints(struct abuf * ab, struct linenoiseState * l, int plen) { +static void refreshShowHints(std::string & ab, struct linenoiseState * l, int plen) { char seq[64]; if (hintsCallback && plen+l->len < l->cols) { int color = -1, bold = 0; @@ -522,10 +524,11 @@ static void refreshShowHints(struct abuf * ab, struct linenoiseState * l, int pl snprintf(seq,64,"\033[%d;%d;49m",bold,color); else seq[0] = '\0'; - abAppend(ab,seq,strlen(seq)); - abAppend(ab,hint,hintlen); + ab.append(seq); + ab.append(hint, hintlen); if (color != -1 || bold != 0) - abAppend(ab,"\033[0m",4); + ab.append("\033[0m"); + /* Call the function to free the hint returned. */ if (freeHintsCallback) freeHintsCallback(hint); } @@ -546,8 +549,7 @@ static void refreshSingleLine(struct linenoiseState *l, int flags) { char *buf = l->buf; size_t len = l->len; size_t pos = l->pos; - struct abuf ab; - + std::string ab; while((plen+pos) >= l->cols) { buf++; len--; @@ -557,35 +559,34 @@ static void refreshSingleLine(struct linenoiseState *l, int flags) { len--; } - abInit(&ab); /* Cursor to left edge */ snprintf(seq,sizeof(seq),"\r"); - abAppend(&ab,seq,strlen(seq)); + ab.append(seq); if (flags & REFRESH_WRITE) { /* Write the prompt and the current buffer content */ - abAppend(&ab,l->prompt,strlen(l->prompt)); + ab.append(l->prompt); if (maskmode == 1) { - while (len--) abAppend(&ab,"*",1); + while (len--) { + ab.append("*"); + } } else { - abAppend(&ab,buf,len); + ab.append(buf, len); } /* Show hits if any. */ - refreshShowHints(&ab,l,plen); + refreshShowHints(ab, l, plen); } /* Erase to right */ snprintf(seq,sizeof(seq),"\x1b[0K"); - abAppend(&ab,seq,strlen(seq)); - + ab.append(seq); if (flags & REFRESH_WRITE) { /* Move cursor to original position. */ snprintf(seq,sizeof(seq),"\r\x1b[%dC", (int)(pos+plen)); - abAppend(&ab,seq,strlen(seq)); + ab.append(seq); } - if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */ - abFree(&ab); + (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ } /* Multi line low level line refresh. @@ -604,26 +605,23 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) { int col; /* colum position, zero-based. */ int old_rows = l->oldrows; int fd = l->ofd, j; - struct abuf ab; - + std::string ab; l->oldrows = rows; /* First step: clear all the lines used before. To do so start by * going to the last row. */ - abInit(&ab); - if (flags & REFRESH_CLEAN) { if (old_rows-rpos > 0) { lndebug("go down %d", old_rows-rpos); snprintf(seq,64,"\x1b[%dB", old_rows-rpos); - abAppend(&ab,seq,strlen(seq)); + ab.append(seq); } /* Now for every row clear it, go up. */ for (j = 0; j < old_rows-1; j++) { lndebug("clear+up"); snprintf(seq,64,"\r\x1b[0K\x1b[1A"); - abAppend(&ab,seq,strlen(seq)); + ab.append(seq); } } @@ -631,21 +629,22 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) { /* Clean the top line. */ lndebug("clear"); snprintf(seq,64,"\r\x1b[0K"); - abAppend(&ab,seq,strlen(seq)); + ab.append(seq); } if (flags & REFRESH_WRITE) { /* Write the prompt and the current buffer content */ - abAppend(&ab,l->prompt,strlen(l->prompt)); + ab.append(l->prompt); if (maskmode == 1) { - unsigned int i; - for (i = 0; i < l->len; i++) abAppend(&ab,"*",1); + for (unsigned int i = 0; i < l->len; ++i) { + ab.append("*"); + } } else { - abAppend(&ab,l->buf,l->len); + ab.append(l->buf, l->len); } /* Show hits if any. */ - refreshShowHints(&ab,l,plen); + refreshShowHints(ab, l, plen); /* If we are at the very end of the screen with our prompt, we need to * emit a newline and move the prompt to the first column. */ @@ -654,9 +653,9 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) { (l->pos+plen) % l->cols == 0) { lndebug(""); - abAppend(&ab,"\n",1); + ab.append("\n"); snprintf(seq,64,"\r"); - abAppend(&ab,seq,strlen(seq)); + ab.append(seq); rows++; if (rows > (int)l->oldrows) l->oldrows = rows; } @@ -669,7 +668,7 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) { if (rows-rpos2 > 0) { lndebug("go-up %d", rows-rpos2); snprintf(seq,64,"\x1b[%dA", rows-rpos2); - abAppend(&ab,seq,strlen(seq)); + ab.append(seq); } /* Set column. */ @@ -679,14 +678,12 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) { snprintf(seq,64,"\r\x1b[%dC", col); else snprintf(seq,64,"\r"); - abAppend(&ab,seq,strlen(seq)); + ab.append(seq); } lndebug("\n"); l->oldpos = l->pos; - - if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */ - abFree(&ab); + (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ } /* Calls the two low level functions refreshSingleLine() or @@ -1313,16 +1310,17 @@ int linenoiseHistorySetMaxLen(int len) { * otherwise -1 is returned. */ int linenoiseHistorySave(const char *filename) { mode_t old_umask = umask(S_IXUSR|S_IRWXG|S_IRWXO); - FILE *fp; - int j; - - fp = fopen(filename,"w"); + File file; + file.open(filename, "w"); umask(old_umask); - if (fp == NULL) return -1; + if (file.file == NULL) { + return -1; + } chmod(filename,S_IRUSR|S_IWUSR); - for (j = 0; j < history_len; j++) - fprintf(fp,"%s\n",history[j]); - fclose(fp); + for (int j = 0; j < history_len; ++j) { + fprintf(file.file, "%s\n", history[j]); + } + return 0; } @@ -1332,12 +1330,14 @@ int linenoiseHistorySave(const char *filename) { * If the file exists and the operation succeeded 0 is returned, otherwise * on error -1 is returned. */ int linenoiseHistoryLoad(const char *filename) { - FILE *fp = fopen(filename,"r"); + File file; + file.open(filename, "r"); char buf[LINENOISE_MAX_LINE]; + if (file.file == NULL) { + return -1; + } - if (fp == NULL) return -1; - - while (fgets(buf,LINENOISE_MAX_LINE,fp) != NULL) { + while (fgets(buf, LINENOISE_MAX_LINE, file.file) != NULL) { char *p; p = strchr(buf,'\r'); @@ -1345,7 +1345,6 @@ int linenoiseHistoryLoad(const char *filename) { if (p) *p = '\0'; linenoiseHistoryAdd(buf); } - fclose(fp); return 0; } #endif diff --git a/examples/run/linenoise.cpp/linenoise.h b/examples/run/linenoise.cpp/linenoise.h index 3e25f4de3..a14ec6c74 100644 --- a/examples/run/linenoise.cpp/linenoise.h +++ b/examples/run/linenoise.cpp/linenoise.h @@ -45,6 +45,7 @@ extern "C" { #endif #include /* For size_t. */ +#include extern const char *linenoiseEditMore; @@ -69,10 +70,23 @@ struct linenoiseState { int history_index; /* The history index we are currently editing. */ }; -typedef struct linenoiseCompletions { - size_t len; - char **cvec; -} linenoiseCompletions; +struct linenoiseCompletions { + size_t len = 0; + char ** cvec = nullptr; + bool to_free = true; + + ~linenoiseCompletions() { + if (!to_free) { + return; + } + + for (size_t i = 0; i < len; ++i) { + free(cvec[i]); + } + + free(cvec); + } +}; /* Non blocking API. */ int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index dd9ea79e8..92a49eb74 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -28,6 +28,7 @@ #include "json.hpp" #include "linenoise.cpp/linenoise.h" #include "llama-cpp.h" +#include "chat-template.hpp" #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) [[noreturn]] static void sigint_handler(int) { @@ -105,6 +106,7 @@ class Opt { llama_model_params model_params; std::string model_; std::string user; + bool use_jinja = false; int context_size = -1, ngl = -1; float temperature = -1; bool verbose = false; @@ -145,7 +147,8 @@ class Opt { if (handle_option_with_value(argc, argv, i, context_size) == 1) { return 1; } - } else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) { + } else if (options_parsing && + (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) { if (handle_option_with_value(argc, argv, i, ngl) == 1) { return 1; } @@ -156,6 +159,8 @@ class Opt { } else if (options_parsing && (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) { verbose = true; + } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) { + use_jinja = true; } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) { help = true; return 0; @@ -190,7 +195,7 @@ class Opt { "Options:\n" " -c, --context-size \n" " Context size (default: %d)\n" - " -n, --ngl \n" + " -n, -ngl, --ngl \n" " Number of GPU layers (default: %d)\n" " --temp \n" " Temperature (default: %.1f)\n" @@ -630,20 +635,20 @@ class LlamaData { return path.substr(pos + 1); } - int remove_proto(std::string & model_) { - const std::string::size_type pos = model_.find("://"); + int rm_until_substring(std::string & model_, const std::string & substring) { + const std::string::size_type pos = model_.find(substring); if (pos == std::string::npos) { return 1; } - model_ = model_.substr(pos + 3); // Skip past "://" + model_ = model_.substr(pos + substring.size()); // Skip past the substring return 0; } int resolve_model(std::string & model_) { int ret = 0; if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) { - remove_proto(model_); + rm_until_substring(model_, "://"); return ret; } @@ -652,13 +657,16 @@ class LlamaData { const std::vector headers = { "--header", "Accept: application/vnd.docker.distribution.manifest.v2+json" }; if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) { - remove_proto(model_); + rm_until_substring(model_, "://"); + ret = huggingface_dl(model_, headers, bn); + } else if (string_starts_with(model_, "hf.co/")) { + rm_until_substring(model_, "hf.co/"); ret = huggingface_dl(model_, headers, bn); } else if (string_starts_with(model_, "ollama://")) { - remove_proto(model_); + rm_until_substring(model_, "://"); ret = ollama_dl(model_, headers, bn); } else if (string_starts_with(model_, "https://")) { - download(model_, headers, bn, true); + ret = download(model_, headers, bn, true); } else { ret = ollama_dl(model_, headers, bn); } @@ -713,13 +721,31 @@ static void add_message(const char * role, const std::string & text, LlamaData & } // Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(LlamaData & llama_data, const bool append) { +static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { + if (use_jinja) { + json messages = json::array(); + for (const auto & msg : llama_data.messages) { + messages.push_back({ + {"role", msg.role}, + {"content", msg.content}, + }); + } + try { + auto result = tmpl.apply(messages, /* tools= */ json(), append); + llama_data.fmtted.resize(result.size() + 1); + memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); + return result.size(); + } catch (const std::exception & e) { + printe("failed to render the chat template: %s\n", e.what()); + return -1; + } + } int result = llama_chat_apply_template( - llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append, + tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append, append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); if (append && result > static_cast(llama_data.fmtted.size())) { llama_data.fmtted.resize(result); - result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), + result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append, llama_data.fmtted.data(), llama_data.fmtted.size()); } @@ -729,10 +755,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) { // Function to tokenize the prompt static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, - std::vector & prompt_tokens) { - const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); + std::vector & prompt_tokens, const LlamaData & llama_data) { + const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0; + + const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); prompt_tokens.resize(n_prompt_tokens); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) { printe("failed to tokenize the prompt\n"); return -1; @@ -778,7 +806,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); std::vector tokens; - if (tokenize_prompt(vocab, prompt, tokens) < 0) { + if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) { return 1; } @@ -869,8 +897,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) { - const int new_len = apply_chat_template(llama_data, append); +static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { + const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); if (new_len < 0) { printe("failed to apply the chat template\n"); return -1; @@ -929,9 +957,11 @@ static int get_user_input(std::string & user_input, const std::string & user) { } // Main chat loop function -static int chat_loop(LlamaData & llama_data, const std::string & user) { +static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); + auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), ""); + GGML_ASSERT(chat_templates.template_default); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input @@ -942,7 +972,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) { add_message("user", user.empty() ? user_input : user, llama_data); int new_len; - if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) { + if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) { return 1; } @@ -957,7 +987,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) { } add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) { + if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) { return 1; } } @@ -1017,7 +1047,7 @@ int main(int argc, const char ** argv) { return 1; } - if (chat_loop(llama_data, opt.user)) { + if (chat_loop(llama_data, opt.user, opt.use_jinja)) { return 1; } diff --git a/examples/server/README.md b/examples/server/README.md index 1f0a27d96..5022de672 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -126,7 +126,7 @@ The project is under active development, and we are [looking for feedback and co | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | | `--grammar-file FNAME` | file to read grammar from | | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | - +| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) | **Example-specific params** diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index 26f3583bd..582ccc0d3 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d1e8ee829..b1cde2d7f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -267,6 +267,11 @@ struct server_task { params.speculative.n_min = std::max(params.speculative.n_min, 2); params.speculative.n_max = std::max(params.speculative.n_max, 0); + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + if (data.contains("lora")) { if (data.at("lora").is_array()) { params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); @@ -1422,6 +1427,10 @@ struct server_queue { int post(server_task task, bool front = false) { std::unique_lock lock(mutex_tasks); GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } QUE_DBG("new task, id = %d, front = %d\n", task.id, front); if (front) { queue_tasks.push_front(std::move(task)); @@ -1439,6 +1448,10 @@ struct server_queue { if (task.id == -1) { task.id = id++; } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); if (front) { queue_tasks.push_front(std::move(task)); @@ -1539,6 +1552,20 @@ struct server_queue { } } } + +private: + void cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task & task) { + return task.id_target == id_target; + }; + queue_tasks.erase( + std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), + queue_tasks.end()); + queue_tasks_deferred.erase( + std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); + } }; struct server_response { @@ -1574,6 +1601,12 @@ struct server_response { std::unique_lock lock(mutex_results); waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); } void remove_waiting_task_ids(const std::unordered_set & id_tasks) { @@ -1593,7 +1626,7 @@ struct server_response { return !queue_results.empty(); }); - for (int i = 0; i < (int) queue_results.size(); i++) { + for (size_t i = 0; i < queue_results.size(); i++) { if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { server_task_result_ptr res = std::move(queue_results[i]); queue_results.erase(queue_results.begin() + i); @@ -1610,12 +1643,6 @@ struct server_response { server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { while (true) { std::unique_lock lock(mutex_results); - bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{ - return !queue_results.empty(); - }); - if (!cr_res) { - return nullptr; - } for (int i = 0; i < (int) queue_results.size(); i++) { if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { @@ -1624,6 +1651,11 @@ struct server_response { return res; } } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (cr_res == std::cv_status::timeout) { + return nullptr; + } } // should never reach here @@ -1688,6 +1720,8 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; + common_chat_templates chat_templates; + ~server_context() { // Clear any sampling context for (server_slot & slot : slots) { @@ -1728,13 +1762,16 @@ struct server_context { add_bos_token = llama_vocab_get_add_bos(vocab); has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; - if (!params_base.speculative.model.empty()) { + if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) { SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); auto params_dft = params_base; params_dft.devices = params_base.speculative.devices; + params_dft.hf_file = params_base.speculative.hf_file; + params_dft.hf_repo = params_base.speculative.hf_repo; params_dft.model = params_base.speculative.model; + params_dft.model_url = params_base.speculative.model_url; params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_parallel = 1; @@ -1762,16 +1799,44 @@ struct server_context { // force F16 KV cache for the draft model for extra performance cparams_dft.type_k = GGML_TYPE_F16; cparams_dft.type_v = GGML_TYPE_F16; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); } + chat_templates = common_chat_templates_from_model(model, params_base.chat_template); + GGML_ASSERT(chat_templates.template_default.get() != nullptr); + return true; } - bool validate_builtin_chat_template() const { + bool validate_builtin_chat_template(bool use_jinja) const { llama_chat_message chat[] = {{"user", "test"}}; - const char * tmpl = llama_model_chat_template(model); - const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); - return chat_res > 0; + + if (use_jinja) { + auto templates = common_chat_templates_from_model(model, ""); + GGML_ASSERT(templates.template_default); + try { + templates.template_default->apply({{ + {"role", "user"}, + {"content", "test"}, + }}, json(), true); + if (templates.template_tool_use) { + templates.template_tool_use->apply({{ + {"role", "user"}, + {"content", "test"}, + }}, json(), true); + } + return true; + } catch (const std::exception & e) { + SRV_ERR("failed to apply template: %s\n", e.what()); + return false; + } + } else { + const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); + const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); + return chat_res > 0; + } } void init() { @@ -2338,8 +2403,8 @@ struct server_context { server_task task(SERVER_TASK_TYPE_CANCEL); task.id_target = id_task; - cancel_tasks.push_back(task); queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(task); } // push to beginning of the queue, so it has highest priority queue_tasks.post(cancel_tasks, true); @@ -3656,9 +3721,12 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", common_get_builtin_chat_template(ctx_server.model) }, + { "chat_template", ctx_server.chat_templates.template_default->source() }, { "build_info", build_info }, }; + if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { + data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); + } res_ok(res, data); }; @@ -3886,7 +3954,10 @@ int main(int argc, char ** argv) { return; } - json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); + auto body = json::parse(req.body); + const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default; + json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja); + return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, @@ -4296,7 +4367,7 @@ int main(int argc, char ** argv) { // if a custom chat template is not supplied, we will use the one that comes with the model (if any) if (params.chat_template.empty()) { - if (!ctx_server.validate_builtin_chat_template()) { + if (!ctx_server.validate_builtin_chat_template(params.use_jinja)) { LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); params.chat_template = "chatml"; } @@ -4304,8 +4375,8 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(), - common_chat_format_example(ctx_server.model, params.chat_template).c_str()); + ctx_server.chat_templates.template_default->source().c_str(), + common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task(std::bind( &server_context::process_single_task, &ctx_server, std::placeholders::_1)); diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index b15dba6eb..2e15348dc 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -4,22 +4,26 @@ from utils import * server = ServerPreset.tinyllama2() - -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() @pytest.mark.parametrize( - "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", + "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template", [ - (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), + (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None), + (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), ] ) -def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): +def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): global server + server.jinja = jinja + server.chat_template = chat_template server.start() res = server.make_request("POST", "/chat/completions", data={ "model": model, diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 9d1a7a5b0..9964db2f9 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -72,13 +72,14 @@ class ServerProcess: pooling: str | None = None draft: int | None = None api_key: str | None = None - response_format: str | None = None lora_files: List[str] | None = None disable_ctx_shift: int | None = False draft_min: int | None = None draft_max: int | None = None no_webui: bool | None = None + jinja: bool | None = None chat_template: str | None = None + chat_template_file: str | None = None # session variables process: subprocess.Popen | None = None @@ -169,8 +170,12 @@ class ServerProcess: server_args.extend(["--draft-min", self.draft_min]) if self.no_webui: server_args.append("--no-webui") + if self.jinja: + server_args.append("--jinja") if self.chat_template: server_args.extend(["--chat-template", self.chat_template]) + if self.chat_template_file: + server_args.extend(["--chat-template-file", self.chat_template_file]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 699480f90..c5987250c 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -16,6 +16,8 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" +#include "minja.hpp" +#include "chat-template.hpp" #include #include @@ -349,7 +351,7 @@ static llama_tokens format_infill( } // Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages) { +inline std::string format_chat(const common_chat_template & tmpl, const std::vector & messages) { std::vector chat; for (size_t i = 0; i < messages.size(); ++i) { @@ -377,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri chat.push_back({role, content}); } - const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true); + const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); return formatted_chat; @@ -576,14 +578,23 @@ static json oaicompat_completion_params_parse(const json & body) { return llama_params; } -static json oaicompat_chat_completion_params_parse( - const struct llama_model * model, - const json & body, /* openai api json semantics */ - const std::string & chat_template) { +static json oaicompat_completion_params_parse( + const json & body, /* openai api json semantics */ + const common_chat_template & tmpl, + bool use_jinja) +{ json llama_params; - // Apply chat template to the list of messages - llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); + auto tools = json_value(body, "tools", json()); + auto has_tools = tools.is_array() && !tools.empty(); + + if (has_tools) { + if (use_jinja) { + LOG_WRN("tools param is not fully supported yet\n"); + } else { + throw std::runtime_error("tools param requires --jinja flag"); + } + } // Handle "stop" field if (body.contains("stop") && body.at("stop").is_string()) { @@ -606,6 +617,13 @@ static json oaicompat_chat_completion_params_parse( } } + // Apply chat template to the list of messages + if (use_jinja) { + llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); + } else { + llama_params["prompt"] = format_chat(tmpl, body.at("messages")); + } + // Handle "n" field int n_choices = json_value(body, "n", 1); if (n_choices != 1) { @@ -621,7 +639,7 @@ static json oaicompat_chat_completion_params_parse( } // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "tools", "tool_choice" }; + static const std::vector unsupported_params { "tool_choice" }; for (const auto & param : unsupported_params) { if (body.contains(param)) { throw std::runtime_error("Unsupported param: " + param); diff --git a/examples/server/webui/index.html b/examples/server/webui/index.html index 2180ef4ad..d3893ea4e 100644 --- a/examples/server/webui/index.html +++ b/examples/server/webui/index.html @@ -141,6 +141,7 @@ :msg="pendingMsg" :key="pendingMsg.id" :is-generating="isGenerating" + :show-thought-in-progress="config.showThoughtInProgress" :edit-user-msg-and-regenerate="() => {}" :regenerate-msg="() => {}"> @@ -202,6 +203,20 @@ + +
+ Reasoning models +
+
+ + Expand though process by default for generating message +
+
+ + Exclude thought process when sending request to API (Recommended for DeepSeek-R1) +
+
+
Advanced config @@ -261,7 +276,17 @@
- +
+ + + + Thinking + + Thought Process + + +
+
`; } + +/** + * filter out redundant fields upon sending to API + * @param {Array} messages + * @returns {Array} + */ +function normalizeMsgsForAPI(messages) { + return messages.map((msg) => { + return { + role: msg.role, + content: msg.content, + }; + }); +} + +/** + * recommended for DeepsSeek-R1, filter out content between and tags + * @param {Array} messages + * @returns {Array} + */ +function filterThoughtFromMsgs(messages) { + return messages.map((msg) => { + return { + role: msg.role, + content: msg.role === 'assistant' + ? msg.content.split('').at(-1).trim() + : msg.content, + }; + }); +} diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 26422601d..c5534cc13 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -95,13 +95,15 @@ int main(int argc, char ** argv) { llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // helper function to evaluate a prompt and generate a response - auto generate = [&](const std::string & prompt, bool is_first) { + auto generate = [&](const std::string & prompt) { std::string response; + const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0; + // tokenize the prompt const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); std::vector prompt_tokens(n_prompt_tokens); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) { GGML_ABORT("failed to tokenize the prompt\n"); } @@ -161,7 +163,7 @@ int main(int argc, char ** argv) { break; } - const char * tmpl = llama_model_chat_template(model); + const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); // add the user input to the message list and format it messages.push_back({"user", strdup(user.c_str())}); @@ -180,7 +182,7 @@ int main(int argc, char ** argv) { // generate a response printf("\033[33m"); - std::string response = generate(prompt, prev_len == 0); + std::string response = generate(prompt); printf("\n\033[0m"); // add the response to the messages diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 185079aa4..bbabb14de 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -58,7 +58,8 @@ else() set(GGML_BLAS_VENDOR_DEFAULT "Generic") endif() -if (CMAKE_CROSSCOMPILING) +if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH}) + message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF") set(GGML_NATIVE_DEFAULT OFF) else() set(GGML_NATIVE_DEFAULT ON) @@ -153,6 +154,8 @@ option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashA option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) option(GGML_HIP "ggml: use HIP" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 0ed92b3ff..9e627da8c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32( float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); } @@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32( float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); ggml_vec_mad_f32(ne0, d, s0, *s1); } diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 35a1c876c..2ccb4b472 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; case GGML_OP_OUT_PROD: - return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32; + return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && + src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; default: return true; } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index c7b6be4e2..ce4b9cfb5 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s template static __global__ void k_repeat_back( - const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, - const int64_t ne0, const int64_t ne1, const int64_t ne2) { + const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) { - const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; - const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y; - const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z; + const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x; + const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y; + const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z; + const int64_t tid2 = tid23 % ne2; + const int64_t tid3 = tid23 / ne2; if (tid0 >= ne0) { return; } T sum = 0; - for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { - for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { - for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { - sum += src[i2*ne01*ne00 + i1*ne00 + i0]; + for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { + for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { + for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { + for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { + sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00]; + } } } } - dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; + dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; } template @@ -274,12 +279,14 @@ struct bin_bcast_cuda { template static void repeat_back_cuda( - const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, - const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) { + const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); - const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2); - k_repeat_back<<>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2); + const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3); + k_repeat_back<<>> + (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3); } template @@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(src0->type == dst->type); - GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_can_repeat(dst, src0)); cudaStream_t stream = ctx.stream(); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - GGML_ASSERT(src0->ne[3] == 1); + GGML_TENSOR_UNARY_OP_LOCALS; - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - GGML_ASSERT(dst->ne[3] == 1); + GGML_ASSERT(ne2*ne3 <= (1 << 15)); + + const size_t ts = ggml_type_size(src0->type); + const size_t s00 = nb00 / ts; + const size_t s01 = nb01 / ts; + const size_t s02 = nb02 / ts; + const size_t s03 = nb03 / ts; switch (dst->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; float * dst_d = (float *) dst->data; - repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream); + repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream); } break; default: { GGML_ASSERT(false); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2c0a56226..bb6120568 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -131,6 +131,10 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif // GGML_CUDA_F16 +#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) +#define GGML_USE_VMM +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) + #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #define FP16_AVAILABLE #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL @@ -588,7 +592,7 @@ struct ggml_tensor_extra_gpu { }; -#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS) +#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS) #define USE_CUDA_GRAPH #endif diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7fd1fc853..85178abd2 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -62,7 +62,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails - cudaGetDevice(&id); + (void)cudaGetDevice(&id); GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg); GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); @@ -152,7 +152,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if defined(GGML_USE_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -164,7 +164,7 @@ static ggml_cuda_device_info ggml_cuda_init() { alloc_prop.location.id = id; CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // defined(GGML_USE_VMM) info.devices[id].vmm = !!device_vmm; cudaDeviceProp prop; @@ -300,7 +300,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if defined(GGML_USE_VMM) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -309,6 +309,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { size_t pool_used = 0; size_t pool_size = 0; size_t granularity; +#if defined(GGML_USE_HIP) + std::vector> mappings; +#endif explicit ggml_cuda_pool_vmm(int device) : device(device), @@ -317,7 +320,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { ~ggml_cuda_pool_vmm() { if (pool_addr != 0) { +#if defined(GGML_USE_HIP) + // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 + for (std::pair & mapping : mappings) { + CU_CHECK(cuMemUnmap(mapping.first, mapping.second)); + } +#else CU_CHECK(cuMemUnmap(pool_addr, pool_size)); +#endif CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE)); } } @@ -350,7 +360,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { } // map at the end of the pool - CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0)); + CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); + CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); +#if defined(GGML_USE_HIP) + mappings.push_back({start_ptr, reserve_size}); +#endif // the memory allocation handle is no longer needed after mapping CU_CHECK(cuMemRelease(handle)); @@ -360,7 +374,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; access.location.id = device; access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1)); + CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); // add to the pool pool_size += reserve_size; @@ -372,7 +386,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(pool_addr != 0); - void * ptr = (void *) (pool_addr + pool_used); + void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used)); *actual_size = size; pool_used += size; @@ -391,17 +405,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { pool_used -= size; // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void *) (pool_addr + pool_used)); + GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); } }; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // defined(GGML_USE_VMM) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // defined(GGML_USE_VMM) return std::unique_ptr(new ggml_cuda_pool_leg(device)); } @@ -547,7 +561,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err)); return nullptr; } @@ -962,7 +976,7 @@ static void * ggml_cuda_host_malloc(size_t size) { cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, cudaGetErrorString(err)); return nullptr; @@ -1082,7 +1096,9 @@ static void ggml_cuda_op_mul_mat_cublas( const int compute_capability = ggml_cuda_info().devices[id].cc; - if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { + const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + + if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { @@ -1103,28 +1119,38 @@ static void ggml_cuda_op_mul_mat_cublas( to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); - - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; - - cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { - cu_compute_type = CUBLAS_COMPUTE_32F; - } CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - CUBLAS_CHECK( - cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, - row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + if (compute_capability == GGML_CUDA_CC_CDNA) { + const float alpha = 1.0f; + const float beta = 0.0f; + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta, dst_dd_i, CUDA_R_32F, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } } else { ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); ggml_cuda_pool_alloc src1_ddq_as_f32(ctx.pool(id)); @@ -1197,7 +1223,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { CUDA_CHECK(err); } else { // reset the error - cudaGetLastError(); + (void)cudaGetLastError(); } } else { cudaError_t err = cudaDeviceDisablePeerAccess(id_other); @@ -1205,7 +1231,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { CUDA_CHECK(err); } else { // reset the error - cudaGetLastError(); + (void)cudaGetLastError(); } } } @@ -1613,10 +1639,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; cudaDataType_t cu_data_type = CUDA_R_16F; - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { - cu_compute_type = CUBLAS_COMPUTE_32F; - } - // dst strides size_t nbd2 = dst->nb[2]; size_t nbd3 = dst->nb[3]; @@ -1645,6 +1667,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } + if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { + cu_compute_type = CUBLAS_COMPUTE_32F; + alpha = &alpha_f32; + beta = &beta_f32; + } + GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -2438,7 +2466,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto if (stat == cudaErrorInvalidDeviceFunction) { // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. // We don't need to update blas nodes, so clear error and move on. - cudaGetLastError(); + (void)cudaGetLastError(); } else { GGML_ASSERT(stat == cudaSuccess); } @@ -2493,14 +2521,20 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { cudaGraphExecUpdateResultInfo result_info; +#ifdef __HIP_PLATFORM_AMD__ + hipGraphNode_t errorNode; + hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); +#else cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); +#endif if (stat == cudaErrorGraphExecUpdateFailure) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); #endif + // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate - cudaGetLastError(); + (void)cudaGetLastError(); CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); cuda_ctx->cuda_graph->instance = nullptr; CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); @@ -2728,7 +2762,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, cudaGetErrorString(err)); @@ -2748,7 +2782,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) { cudaError_t err = cudaHostUnregister(buffer); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); } } @@ -3002,7 +3036,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; case GGML_OP_REPEAT_BACK: - return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1; + return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15); case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; @@ -3216,7 +3250,7 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t features.push_back({ "FORCE_CUBLAS", "1" }); #endif - #ifdef GGML_CUDA_NO_VMM + #ifndef GGML_USE_VMM features.push_back({ "NO_VMM", "1" }); #endif diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index e3b912d87..4fb466ca0 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda( int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; - if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA + if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2 switch(ncols_y) { case 1: nwarps = 4; @@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda( break; } } + const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; const dim3 block_nums(nblocks, 1, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index 73e3e2c47..c9b2b699c 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { CUBLAS_CHECK(cublasSetStream(handle, stream)); + const int64_t lda = nb01 / sizeof(float); + const int64_t ldc = nb1 / sizeof(float); + const bool src1_T = ggml_is_transposed(src1); const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T; const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); @@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { CUBLAS_CHECK( cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, ne0, ne1, ne01, - &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, src1_d + i3 *s13 + i2 *s12, ldb, - &beta, dst_d + i3 *s3 + i2 *s2, ne0)); + &beta, dst_d + i3 *s3 + i2 *s2, ldc)); } } } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index c905b15d7..8594093f0 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -19,6 +19,12 @@ #define CUBLAS_TF32_TENSOR_OP_MATH 0 #define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_32F HIPBLAS_R_32F +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate @@ -74,6 +80,21 @@ #define cudaMemGetInfo hipMemGetInfo #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute #define cudaStreamCreateWithFlags hipStreamCreateWithFlags #define cudaStreamDestroy hipStreamDestroy #define cudaStreamFireAndForget hipStreamFireAndForget @@ -81,6 +102,28 @@ #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define __trap() do { abort(); __builtin_unreachable(); } while(0) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index d090ba9bd..ecc3bc66d 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -92,6 +92,14 @@ if (GGML_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() +if (GGML_HIP_GRAPHS) + add_compile_definitions(GGML_HIP_GRAPHS) +endif() + +if (GGML_HIP_NO_VMM) + add_compile_definitions(GGML_HIP_NO_VMM) +endif() + if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) target_link_libraries(ggml-hip PRIVATE hip::device) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 8ba43904d..44f04c909 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4416,7 +4416,6 @@ void kernel_mul_mv_q2_K_f32_impl( device const half * dh = &x[ib].d; for (int row = 0; row < N_DST; row++) { - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; i += 2) { @@ -4447,7 +4446,7 @@ void kernel_mul_mv_q2_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -4613,7 +4612,7 @@ void kernel_mul_mv_q3_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { dst_f32[first_row + row] = sumf1[row]; } } @@ -4729,7 +4728,7 @@ void kernel_mul_mv_q4_K_f32_impl( device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -4861,7 +4860,7 @@ void kernel_mul_mv_q5_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = tot; @@ -4906,6 +4905,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int row = 2*r0 + sgitg; + if (row >= args.ne0) { + return; + } + const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5061,7 +5064,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.25f; @@ -5179,7 +5182,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.25f; @@ -5289,7 +5292,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.5f; @@ -5401,7 +5404,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5514,7 +5517,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.25f; @@ -5614,7 +5617,7 @@ void kernel_mul_mv_iq1_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5709,7 +5712,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5799,7 +5802,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5888,7 +5891,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 63da2b86b..3d0c46578 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -181,7 +181,7 @@ struct ggml_backend_rpc_context { struct ggml_backend_rpc_buffer_context { std::shared_ptr sock; - std::unordered_map base_cache; + void * base_ptr; uint64_t remote_ptr; }; @@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) { - return ctx->base_cache[buffer]; + if (ctx->base_ptr != nullptr) { + return ctx->base_ptr; } rpc_msg_buffer_get_base_req request = {ctx->remote_ptr}; rpc_msg_buffer_get_base_rsp response; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response)); GGML_ASSERT(status); - void * base_ptr = reinterpret_cast(response.base_ptr); - ctx->base_cache[buffer] = base_ptr; - return base_ptr; + ctx->base_ptr = reinterpret_cast(response.base_ptr); + return ctx->base_ptr; } static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { @@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back if (response.remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, - new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr}, + new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr}, response.remote_size); return buffer; } else { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 437e9cdcc..a9d6b923c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -29,8 +29,6 @@ #include "ggml-vulkan-shaders.hpp" -#define VK_API_VERSION VK_API_VERSION_1_2 - #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) #define VK_VENDOR_ID_AMD 0x1002 @@ -87,6 +85,10 @@ struct vk_pipeline_struct { uint32_t parameter_count; std::array wg_denoms; uint32_t align; + // set to true to request the pipeline is compiled after the dryrun + bool needed {}; + // set to true when the shader has been compiled + bool compiled {}; }; typedef std::shared_ptr vk_pipeline; @@ -188,8 +190,11 @@ struct vk_device_struct { bool mul_mat_id_m; bool mul_mat_id_s; - vk_matmul_pipeline pipeline_matmul_f32; - vk_matmul_pipeline pipeline_matmul_f32_f16; + // set to true to indicate that some shaders need to be compiled after the dryrun + bool need_compiles {}; + + vk_matmul_pipeline pipeline_matmul_f32 {}; + vk_matmul_pipeline pipeline_matmul_f32_f16 {}; vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16_f32; vk_pipeline pipeline_matmul_split_k_reduce; @@ -197,7 +202,7 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; - vk_matmul_pipeline pipeline_matmul_id_f32; + vk_matmul_pipeline pipeline_matmul_id_f32 {}; vk_matmul_pipeline2 pipeline_matmul_id_f16; vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; @@ -778,13 +783,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin GGML_ASSERT(parameter_count > 0); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT - pipeline = std::make_shared(); - pipeline->name = name; - pipeline->parameter_count = parameter_count; - pipeline->push_constant_size = push_constant_size; - pipeline->wg_denoms = wg_denoms; - pipeline->align = align; - vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); @@ -867,6 +865,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + pipeline->compiled = true; { std::lock_guard guard(device->mutex); @@ -877,12 +876,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::lock_guard guard(compile_count_mutex); assert(compile_count > 0); compile_count--; - - // "Progress bar" for shader compiles - static uint32_t total_compile_count = 0; - if ((total_compile_count++ % 10) == 0) { - std::cerr << "."; - } } compile_count_cond.notify_all(); } @@ -908,6 +901,10 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); device->pipeline_descriptor_set_requirements[pipeline->name] += n; + if (!pipeline->compiled) { + pipeline->needed = true; + device->need_compiles = true; + } } static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { @@ -1390,8 +1387,6 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); - std::cerr << "ggml_vulkan: Compiling shaders"; - // some shaders have a minimum subgroup size const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); @@ -1529,15 +1524,33 @@ static void ggml_vk_load_shaders(vk_device& device) { } } - device->pipeline_matmul_f32 = std::make_shared(); - device->pipeline_matmul_f32_f16 = std::make_shared(); - - device->pipeline_matmul_id_f32 = std::make_shared(); + if (!device->pipeline_matmul_f32) { + device->pipeline_matmul_f32 = std::make_shared(); + } + if (!device->pipeline_matmul_f32_f16) { + device->pipeline_matmul_f32_f16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_f32) { + device->pipeline_matmul_id_f32 = std::make_shared(); + } std::vector> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + + if (!pipeline) { + pipeline = std::make_shared(); + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + } + + if (!pipeline->needed || pipeline->compiled) { + return; + } { // wait until fewer than N compiles are in progress uint32_t N = std::max(1u, std::thread::hardware_concurrency()); @@ -1614,11 +1627,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -1631,21 +1640,18 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 } else @@ -1682,31 +1688,31 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. @@ -1716,31 +1722,31 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } } #undef CREATE_MM2 @@ -2021,7 +2027,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); @@ -2059,7 +2065,7 @@ static void ggml_vk_load_shaders(vk_device& device) { for (auto &c : compiles) { c.wait(); } - std::cerr << "Done!" << std::endl; + device->need_compiles = false; } static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); @@ -2287,6 +2293,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + VkPhysicalDeviceMaintenance4Features maint4_features {}; + maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&maint4_features; + last_struct = (VkBaseOutStructure *)&maint4_features; + device_extensions.push_back("VK_KHR_maintenance4"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; @@ -2662,7 +2676,14 @@ void ggml_vk_instance_init() { vk_instance_initialized = true; - vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; + uint32_t api_version = vk::enumerateInstanceVersion(); + + if (api_version < VK_API_VERSION_1_2) { + std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl; + GGML_ABORT("fatal error"); + } + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); @@ -2972,7 +2993,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } - GGML_ASSERT(src1_type == GGML_TYPE_F32); + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { case GGML_TYPE_Q4_0: @@ -3812,8 +3833,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub src1_uma = d_Qy != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src1); @@ -4393,8 +4415,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ids_uma = d_ids != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src1); const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; @@ -4404,7 +4429,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; if (qx_needs_dequant) { - GGML_ABORT("fatal error"); + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); } // Not implemented @@ -7645,6 +7671,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg for (int i = 0; i < cgraph->n_nodes; i++) { ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } ggml_vk_preallocate_buffers(ctx); ggml_pipeline_allocate_descriptor_sets(ctx->device); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp index 4e68742b5..26d8bc22a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -12,7 +12,7 @@ layout (push_constant) uniform parameter #include "types.comp" -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index ca3a59b8f..3735d0dbb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -166,7 +166,7 @@ void main() { tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); - coopmat Q; + coopmat Q; coopmat Qf16; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index cbfa5dce1..57f9e7245 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #if QUANT_K > 1 #define DECODEFUNCA , dequantFuncA -#define MAT_A_TYPE float16_t #include "dequant_funcs_cm2.comp" #else #define DECODEFUNCA -#define MAT_A_TYPE A_TYPE #endif -#define MAT_B_TYPE B_TYPE - #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; @@ -236,16 +232,13 @@ void main() { for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopmat mat_a_ft = coopmat(mat_a); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); - coopmat mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } } else #endif // !defined(MUL_MAT_ID) @@ -261,10 +254,8 @@ void main() { [[dont_unroll]] for (uint block_k = start_k; block_k < end_k; block_k += BK) { - coopmat mat_a; - coopmat mat_b; - coopmat mat_a_ft; - coopmat mat_b_ft; + coopmat mat_a; + coopmat mat_b; // Clamping is expensive, so detect different code paths for each combination // of A and B needing clamping. @@ -281,16 +272,12 @@ void main() { #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); #endif - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (unclampedA && !unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (!unclampedA && unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID @@ -298,16 +285,12 @@ void main() { #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); #endif - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (!unclampedA && !unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b7890ef1e..e9c6cb9d4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -17,13 +17,13 @@ #include #include #include +#include #include #include #ifdef _WIN32 #include #include // For _mkdir on Windows - #include // For std::replace on w64devkit #else #include #include @@ -316,8 +316,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + // don't generate f32 variants for coopmat2 + if (!coopmat2) { + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } if (tname != "f16" && tname != "f32") { string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); @@ -499,6 +502,7 @@ void write_output_files() { fprintf(hdr, "#include \n\n"); fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + std::sort(shader_fnames.begin(), shader_fnames.end()); for (const auto& pair : shader_fnames) { const std::string& name = pair.first; #ifdef _WIN32 diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b1d0d4913..92c4294c5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5339,7 +5339,7 @@ static void ggml_compute_backward( } break; case GGML_OP_MUL: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1)); } if (src1_needs_grads) { struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad); @@ -5431,21 +5431,25 @@ static void ggml_compute_backward( // src1.shape [n,p,qq,rr] if (src0_needs_grads) { - struct ggml_tensor * s1_tg = + GGML_ASSERT(grad->ne[2] == src1->ne[2]); + GGML_ASSERT(grad->ne[3] == src1->ne[3]); + struct ggml_tensor * tmp = ggml_out_prod(ctx, // [n,m,qq,rr] src1, // [n,p,qq,rr] grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = ggml_repeat_back(ctx, s1_tg, src0); + if (!ggml_are_same_shape(tmp, src0)) { + GGML_ASSERT(tmp->ne[0] == src0->ne[0]); + GGML_ASSERT(tmp->ne[1] == src0->ne[1]); + GGML_ASSERT(tmp->ne[3] == 1); + + const int64_t nr2 = tmp->ne[2] / src0->ne[2]; + const size_t nb2 = tmp->nb[2] * nr2; + const size_t nb3 = tmp->nb[2]; + + tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0); + tmp = ggml_repeat_back(ctx, tmp, src0); } - ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/); + ggml_add_or_set(ctx, cgraph, isrc0, tmp); } if (src1_needs_grads) { ggml_add_or_set(ctx, cgraph, isrc1, @@ -5514,7 +5518,9 @@ static void ggml_compute_backward( if (src0_needs_grads) { GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0])); GGML_ASSERT(ggml_is_contiguous(grad)); - ggml_add_or_set(ctx, cgraph, isrc0, grad); + GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0)); + ggml_add_or_set(ctx, cgraph, isrc0, + ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0)); } } break; case GGML_OP_RESHAPE: { diff --git a/include/llama.h b/include/llama.h index 298b8d1bc..3b75e7607 100644 --- a/include/llama.h +++ b/include/llama.h @@ -510,7 +510,8 @@ extern "C" { LLAMA_API uint64_t llama_model_size(const struct llama_model * model); // Get the default chat template. Returns nullptr if not available - LLAMA_API const char * llama_model_chat_template(const struct llama_model * model); + // If name is NULL, returns the default chat template + LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name); // Returns the total number of parameters in the model LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); diff --git a/models/ggml-vocab-deepseek-r1-qwen.gguf.inp b/models/ggml-vocab-deepseek-r1-qwen.gguf.inp new file mode 100644 index 000000000..9baf7d77a --- /dev/null +++ b/models/ggml-vocab-deepseek-r1-qwen.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-r1-qwen.gguf.out b/models/ggml-vocab-deepseek-r1-qwen.gguf.out new file mode 100644 index 000000000..18b4b45cd --- /dev/null +++ b/models/ggml-vocab-deepseek-r1-qwen.gguf.out @@ -0,0 +1,46 @@ + 1122 220 19 220 26062 3951 + 37 50753 261 + + 220 + 256 + 262 + 197 + 198 + 271 + 1406 + 1572 + 9707 1879 + 21927 1879 + 9707 4337 + 21927 4337 + 21927 4337 0 + 9707 11 1879 0 + 21927 11 1879 0 + 419 374 11162 99 247 13 10821 + 86 15 19 23 220 22 83 1963 41808 11472 2940 16739 + 78762 14144 1456 13073 63471 33594 3038 133178 79012 + 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848 + 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8 + 9707 + 21927 + 220 21927 + 256 21927 + 262 21927 + 262 21927 198 262 21927 + 320 + 198 284 + 6 11385 + 9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 + 17085 2928 + 18 + 18 18 + 18 18 18 + 18 18 18 18 + 18 18 18 18 18 + 18 18 18 18 18 18 + 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 18 + 34 90063 128324 + 2560 2347 + 198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43 diff --git a/scripts/get_hf_chat_template.py b/scripts/get_hf_chat_template.py new file mode 100755 index 000000000..23bb1de59 --- /dev/null +++ b/scripts/get_hf_chat_template.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +''' + Fetches the Jinja chat template of a HuggingFace model. + If a model has multiple chat templates, you can specify the variant name. + + Syntax: + ./scripts/get_hf_chat_template.py model_id [variant] + + Examples: + ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct + ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use + ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct +''' + +import json +import re +import sys + + +def get_hf_chat_template(model_id, variant=None): + try: + # Use huggingface_hub library if available. + # Allows access to gated models if the user has access and ran `huggingface-cli login`. + from huggingface_hub import hf_hub_download + with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f: + config_str = f.read() + except ImportError: + import requests + assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}" + response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json") + if response.status_code == 401: + raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`') + response.raise_for_status() + config_str = response.text + + try: + config = json.loads(config_str) + except json.JSONDecodeError: + # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json + # (Remove extra '}' near the end of the file) + config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) + + chat_template = config['chat_template'] + if isinstance(chat_template, str): + return chat_template + else: + variants = { + ct['name']: ct['template'] + for ct in chat_template + } + + def format_variants(): + return ', '.join(f'"{v}"' for v in variants.keys()) + + if variant is None: + if 'default' not in variants: + raise Exception(f'Please specify a chat template variant (one of {format_variants()})') + variant = 'default' + sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n') + elif variant not in variants: + raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})") + + return variants[variant] + + +def main(args): + if len(args) < 1: + raise ValueError("Please provide a model ID and an optional variant name") + model_id = args[0] + variant = None if len(args) < 2 else args[1] + + template = get_hf_chat_template(model_id, variant) + sys.stdout.write(template) + + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index aeb75bf3e..e1b02e4c0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -29,7 +29,7 @@ add_library(llama unicode-data.cpp ) -target_include_directories(llama PUBLIC . ../include) +target_include_directories(llama PUBLIC . ../include ../common) target_compile_features (llama PUBLIC cxx_std_17) # don't bump target_link_libraries(llama PUBLIC ggml) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d7d277e72..a7260f495 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -179,6 +179,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, @@ -1443,10 +1444,11 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; -LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {} +LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { - return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix) + : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); } std::string LLM_TN_IMPL::str() const { diff --git a/src/llama-arch.h b/src/llama-arch.h index 349844790..122fdcebe 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -177,6 +177,7 @@ enum llm_kv { LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, @@ -335,9 +336,10 @@ enum llm_tensor_layer { }; struct LLM_KV { - LLM_KV(llm_arch arch); + LLM_KV(llm_arch arch, const char * suffix = nullptr); llm_arch arch; + const char * suffix; std::string operator()(llm_kv kv) const; }; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 1347ec156..5c19bab24 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -152,7 +152,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_MINICPM; } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { return LLM_CHAT_TEMPLATE_DEEPSEEK_2; - } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) { + } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) { return LLM_CHAT_TEMPLATE_DEEPSEEK_3; } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 57c6e4f51..b716630a8 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #ifdef __has_include #if __has_include() diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c2d23a8d3..031b4c30b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2203,6 +2203,50 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } } break; + case LLM_ARCH_PHIMOE: + { + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; case LLM_ARCH_PLAMO: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3911,8 +3955,10 @@ uint64_t llama_model_size(const struct llama_model * model) { return model->size(); } -const char * llama_model_chat_template(const struct llama_model * model) { - const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)); +const char * llama_model_chat_template(const struct llama_model * model, const char * name) { + const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); + const auto & it = model->gguf_kv.find(key); if (it == model->gguf_kv.end()) { return nullptr; } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 9a680aed4..0782d3a41 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1523,7 +1523,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; clean_spaces = false; } else if ( - tokenizer_pre == "qwen2") { + tokenizer_pre == "qwen2" || + tokenizer_pre == "deepseek-r1-qwen") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; } else if ( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 74d1bee39..468016403 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1302,6 +1302,59 @@ struct test_repeat : public test_case { } }; +// GGML_OP_REPEAT_BACK +struct test_repeat_back : public test_case { + const ggml_type type; + const std::array ne; + const std::array nr; + const bool v; // whether src is a noncontiguous view + + std::string vars() override { + return VARS_TO_STR4(type, ne, nr, v); + } + + size_t op_size(ggml_tensor * t) override { + return ggml_nbytes(t) * 2; + } + + test_repeat_back(ggml_type type = GGML_TYPE_F32, + std::array ne = {8, 6, 4, 2}, + std::array nr = {2, 2, 2, 2}, + bool v = false) + : type(type), ne(ne), nr(nr), v(v) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]); + ggml_set_name(src, "src"); + + if (v) { + GGML_ASSERT(ne[0] % 2 == 0); + GGML_ASSERT(ne[1] % 2 == 0); + GGML_ASSERT(ne[2] % 2 == 0); + GGML_ASSERT(ne[3] % 2 == 0); + GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1); + GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1); + GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1); + GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1); + + const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2; + const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2; + const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2; + const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2; + + src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0); + } + + ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(target, "target"); + + ggml_tensor * out = ggml_repeat_back(ctx, src, target); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_DUP struct test_dup : public test_case { const ggml_type type; @@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case { return 5e-4; } + int64_t grad_nmax() override { + return 20000; + } + uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1]; @@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case { a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]); b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); + if (!ggml_is_quantized(type_a)) { + if (bs[1] == 1 && nr[1] == 1) { + ggml_set_param(ctx, a); + } + ggml_set_param(ctx, b); + } ggml_set_name(a, "a"); ggml_set_name(b, "b"); @@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case { } else { a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); + if (!ggml_is_quantized(type_a)) { + if (bs[1] == 1 && nr[1] == 1) { + ggml_set_param(ctx, a); + } + ggml_set_param(ctx, b); + } ggml_set_name(a, "a"); ggml_set_name(b, "b"); } @@ -3798,6 +3863,16 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2})); } + for (bool view : {false, true}) { + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view)); + } + test_cases.emplace_back(new test_dup(GGML_TYPE_F32)); test_cases.emplace_back(new test_dup(GGML_TYPE_F16)); test_cases.emplace_back(new test_dup(GGML_TYPE_I32)); @@ -3909,38 +3984,35 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4)); test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4)); - for (int i = 1; i < 9; ++i) { - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + for (ggml_type type_a : all_types) { + for (int i = 1; i < 10; ++i) { + test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + } } #if 1 for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { // test cases without permutation - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2})); // test cases with permutation test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3})); diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 5ec318bd4..190643136 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -7,6 +7,16 @@ #include "llama.h" #include "common.h" +#include "chat-template.hpp" + +static std::string normalize_newlines(const std::string & s) { +#ifdef _WIN32 + static const std::regex nl_regex("\r\n"); + return std::regex_replace(s, nl_regex, "\n"); +#else + return s; +#endif +} int main(void) { std::vector conversation { @@ -21,156 +31,228 @@ int main(void) { std::string name; std::string template_str; std::string expected_output; + std::string expected_output_jinja; + std::string bos_token = ""; + std::string eos_token = ""; + bool supported_with_jinja = true; }; std::vector test_cases { { /* .name= */ "teknium/OpenHermes-2.5-Mistral-7B", /* .template_str= */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", /* .expected_output= */ "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "TheBloke/FusionNet_34Bx2_MoE-AWQ", - /* .template_str= */ "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <>\\\\n' + messages[idx]['content'] + '\\\\n<>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}", - /* .expected_output= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .template_str= */ "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}", + /* .expected_output= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .expected_output_jinja= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "bofenghuang/vigogne-2-70b-chat", - /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\\\n' + system_message + '\\\\n<>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\\\n' + content.strip() + '\\\\n<>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - /* .expected_output= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", + /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + /* .expected_output= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", + /* .expected_output_jinja= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "mlabonne/AlphaMonarch-7B", /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", - /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "google/gemma-7b-it", /* .template_str= */ "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}", - /* .expected_output= */ "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", + /* .expected_output= */ "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", + /* .expected_output_jinja= */ "user\nYou are a helpful assistant\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", }, { /* .name= */ "OrionStarAI/Orion-14B-Chat", /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", - /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "openchat/openchat-3.5-0106", // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d // So we match against the included template but implement the suggested version. /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", - /* .expected_output= */ "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", + /* .expected_output= */ "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", + /* .expected_output_jinja= */ "GPT4 Correct System: You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", }, { /* .name= */ "deepseek-ai/deepseek-coder-33b-instruct", /* .template_str= */ "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", /* .expected_output= */ "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", + /* .expected_output_jinja= */ "", }, { /* .name= */ "eachadea/vicuna-13b-1.1", // No template included in tokenizer_config.json, so this template likely needs to be manually set. /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", /* .expected_output= */ "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "Orca-Vicuna", // No template included in tokenizer_config.json, so this template likely needs to be manually set. /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", /* .expected_output= */ "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "CohereForAI/c4ai-command-r-plus", /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", /* .expected_output= */ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + /* .expected_output_jinja= */ "", }, { /* .name= */ "Llama-3", /* .template_str= */ "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}", /* .expected_output= */ "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + /* .expected_output_jinja= */ "", }, { /* .name= */ "Phi-3-mini", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", }, { /* .name= */ "Phi-3-small", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_jinja= */ "", }, { /* .name= */ "Phi-3-medium", /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", }, { /* .name= */ "Phi-3-vision", /* .template_str= */ "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "ChatGLM3", /* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - /* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", + /* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", + /* .expected_output_jinja= */ "[gMASK]sop<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", }, { /* .name= */ "ChatGLM4", /* .template_str= */ u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", /* .expected_output= */ "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF", /* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}", /* .expected_output= */ u8"You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "DeepSeek-V2", /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", /* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "<|end▁of▁sentence|>", }, { /* .name= */ "ibm-granite/granite-3.0-8b-instruct", /* .template_str= */ "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}", - /* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n", + /* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n", + /* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>", }, { /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)", /* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n", /* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)", /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", - /* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] Another question[/INST]", + /* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] Another question[/INST]", + /* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] You are a helpful assistant\n\nAnother question[/INST]", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)", /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", - /* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]Another question[/INST]", + /* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]Another question[/INST]", + /* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]You are a helpful assistant\n\nAnother question[/INST]", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}", /* .expected_output= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant [INST] Another question[/INST]", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "ai-sage/GigaChat-20B-A3B-instruct", /* .template_str= */ "{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}", /* .expected_output= */ "You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", + /* .supported_with_jinja= */ false, // Requires additional_special_tokens as extra context }, { /* .name= */ "Infinigence/Megrez-3B-Instruct", /* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}", /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "phi-4", /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}", /* .expected_output= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, }; std::vector formatted_chat(1024); @@ -190,6 +272,7 @@ int main(void) { // test invalid chat template res = llama_chat_apply_template("INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size()); assert(res < 0); + const auto add_generation_prompt = true; for (const auto & test_case : test_cases) { printf("\n\n=== %s ===\n\n", test_case.name.c_str()); @@ -198,26 +281,59 @@ int main(void) { test_case.template_str.c_str(), conversation.data(), conversation.size(), - true, + add_generation_prompt, formatted_chat.data(), formatted_chat.size() ); formatted_chat.resize(res); std::string output(formatted_chat.data(), formatted_chat.size()); - printf("%s\n", output.c_str()); - printf("-------------------------\n"); - assert(output == test_case.expected_output); + if (output != test_case.expected_output) { + printf("Expected:\n%s\n", test_case.expected_output.c_str()); + printf("-------------------------\n"); + printf("Actual:\n%s\n", output.c_str()); + fflush(stdout); + assert(output == test_case.expected_output); + } } + json messages = json::array(); + for (const auto & msg : conversation) { + messages.push_back({ + {"role", msg.role}, + {"content", msg.content}, + }); + } + for (const auto & test_case : test_cases) { + if (!test_case.supported_with_jinja) { + continue; + } + printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str()); + try { + minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token); + auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt)); + auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja); + if (output != expected_output) { + printf("Expected:\n%s\n", expected_output.c_str()); + printf("-------------------------\n"); + printf("Actual:\n%s\n", output.c_str()); + fflush(stdout); + assert(output == expected_output); + } + } catch (const std::exception & e) { + printf("ERROR: %s\n", e.what()); + assert(false); + } + } // test llama_chat_format_single for system message printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); std::vector chat2; common_chat_msg sys_msg{"system", "You are a helpful assistant"}; - auto fmt_sys = [&](std::string tmpl) { - auto output = common_chat_format_single(nullptr, tmpl, chat2, sys_msg, false); - printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str()); + auto fmt_sys = [&](std::string tmpl_str) { + minja::chat_template tmpl(tmpl_str, "", ""); + auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false); + printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); return output; }; @@ -241,9 +357,10 @@ int main(void) { chat2.push_back({"assistant", "I am assistant"}); common_chat_msg new_msg{"user", "How are you"}; - auto fmt_single = [&](std::string tmpl) { - auto output = common_chat_format_single(nullptr, tmpl, chat2, new_msg, true); - printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str()); + auto fmt_single = [&](std::string tmpl_str) { + minja::chat_template tmpl(tmpl_str, "", ""); + auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false); + printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); return output; }; @@ -258,7 +375,5 @@ int main(void) { assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>"); - printf("Test chat templates: OK\n"); - return 0; }