diff --git a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml index f10b3a2b2..b85bf5741 100644 --- a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml +++ b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml @@ -65,12 +65,22 @@ body: If possible, please do a git bisect and identify the exact commit that introduced the bug. validations: required: false + - type: textarea + id: command + attributes: + label: Compile command + description: > + Please provide the exact command you used to compile llama.cpp. For example: `cmake -B ...`. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: true - type: textarea id: logs attributes: label: Relevant log output description: > - Please copy and paste any relevant log output, including the command that you entered and any generated text. + Please copy and paste any relevant log output, including any generated text. This will be automatically formatted into code, so no need for backticks. render: shell validations: diff --git a/.github/ISSUE_TEMPLATE/019-bug-misc.yml b/.github/ISSUE_TEMPLATE/019-bug-misc.yml index d157ea307..1904e31fd 100644 --- a/.github/ISSUE_TEMPLATE/019-bug-misc.yml +++ b/.github/ISSUE_TEMPLATE/019-bug-misc.yml @@ -52,6 +52,16 @@ body: - Other (Please specify in the next section) validations: required: false + - type: textarea + id: command + attributes: + label: Command line + description: > + Please provide the exact commands you entered, if applicable. For example: `llama-server -m ... -c ...`, `llama-cli -m ...`, etc. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: false - type: textarea id: info attributes: @@ -74,7 +84,7 @@ body: attributes: label: Relevant log output description: > - If applicable, please copy and paste any relevant log output, including the command that you entered and any generated text. + If applicable, please copy and paste any relevant log output, including any generated text. This will be automatically formatted into code, so no need for backticks. render: shell validations: diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a377eff38..c85999b89 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -60,8 +60,7 @@ jobs: -DLLAMA_CURL=ON \ -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=ON \ - -DGGML_RPC=ON \ - -DBUILD_SHARED_LIBS=OFF + -DGGML_RPC=ON cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) - name: Test @@ -123,8 +122,7 @@ jobs: -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_CURL=ON \ -DGGML_METAL=OFF \ - -DGGML_RPC=ON \ - -DBUILD_SHARED_LIBS=OFF + -DGGML_RPC=ON cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) - name: Test @@ -181,7 +179,7 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF + cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_RPC=ON cmake --build . 
--config Release -j $(nproc) - name: Test @@ -651,23 +649,23 @@ jobs: matrix: include: - build: 'noavx-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF' - build: 'avx2-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON' - build: 'avx-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX2=OFF -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX2=OFF' - build: 'avx512-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX512=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX512=ON' - build: 'openblas-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BLAS=ON -DBUILD_SHARED_LIBS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' - build: 'kompute-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON' - build: 'vulkan-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_VULKAN=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_VULKAN=ON' - build: 'llvm-arm64' - defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON' - build: 'msvc-arm64' - defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-msvc.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-msvc.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON' - build: 'llvm-arm64-opencl-adreno' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' @@ -914,7 +912,7 @@ jobs: shell: cmd run: | call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat" - cmake -S . -B build -G "Ninja Multi-Config" -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=ON -DGGML_RPC=ON + cmake -S . 
-B build -G "Ninja Multi-Config" -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DGGML_RPC=ON set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config Release -j %NINJA_JOBS% -t ggml cmake --build build --config Release @@ -1239,7 +1237,7 @@ jobs: - name: Create release id: create_release - uses: anzz1/action-create-release@v1 + uses: ggml-org/action-create-release@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 41f1a89ee..d71f1eb38 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -97,10 +97,9 @@ jobs: GITHUB_BRANCH_NAME: ${{ github.head_ref || github.ref_name }} GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}' - # https://github.com/jlumbroso/free-disk-space/tree/54081f138730dfa15788a46383842cd2f914a1be#example - name: Free Disk Space (Ubuntu) if: ${{ matrix.config.free_disk_space == true }} - uses: jlumbroso/free-disk-space@main + uses: ggml-org/free-disk-space@v1.3.1 with: # this might remove tools that are actually needed, # if set to "true" but frees about 6 GB diff --git a/.github/workflows/editorconfig.yml b/.github/workflows/editorconfig.yml index ae86e9927..f02b7c219 100644 --- a/.github/workflows/editorconfig.yml +++ b/.github/workflows/editorconfig.yml @@ -23,5 +23,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: editorconfig-checker/action-editorconfig-checker@main + - uses: editorconfig-checker/action-editorconfig-checker@v2 + with: + version: v3.0.3 - run: editorconfig-checker diff --git a/CODEOWNERS b/CODEOWNERS index adeba5395..72d594b46 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,5 +1,11 @@ # collaborators can optionally add themselves here to indicate their availability for reviewing related PRs /ci/ @ggerganov -/.devops/ @ngxson +/.devops/*.Dockerfile @ngxson /examples/server/ @ngxson +/ggml/src/ggml-cuda/fattn* @JohannesGaessler +/ggml/src/ggml-cuda/mmq.* @JohannesGaessler +/ggml/src/ggml-cuda/mmv.* @JohannesGaessler +/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler +/ggml/src/ggml-opt.cpp @JohannesGaessler +/ggml/src/gguf.cpp @JohannesGaessler diff --git a/README.md b/README.md index d6d1958c8..a71015256 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen) - [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557) - [x] [Phi models](https://huggingface.co/models?search=microsoft/phi) +- [x] [PhiMoE](https://github.com/ggerganov/llama.cpp/pull/11003) - [x] [GPT-2](https://huggingface.co/gpt2) - [x] [Orion 14B](https://github.com/ggerganov/llama.cpp/pull/5118) - [x] [InternLM2](https://huggingface.co/models?search=internlm2) @@ -201,6 +202,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [Paddler](https://github.com/distantmagic/paddler) - Stateful load balancer custom-tailored for llama.cpp - [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs - [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly +- [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server diff --git a/common/arg.cpp b/common/arg.cpp index deb113786..27886b84e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -22,6 +22,11 @@ 
common_arg & common_arg::set_examples(std::initializer_list return *this; } +common_arg & common_arg::set_excludes(std::initializer_list excludes) { + this->excludes = std::move(excludes); + return *this; +} + common_arg & common_arg::set_env(const char * env) { help = help + "\n(env: " + env + ")"; this->env = env; @@ -37,6 +42,10 @@ bool common_arg::in_example(enum llama_example ex) { return examples.find(ex) != examples.end(); } +bool common_arg::is_exclude(enum llama_example ex) { + return excludes.find(ex) != excludes.end(); +} + bool common_arg::get_value_from_env(std::string & output) { if (env == nullptr) return false; char * value = std::getenv(env); @@ -420,7 +429,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex * - if both {LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_*,} are set, we will prioritize the LLAMA_EXAMPLE_* matching current example */ auto add_opt = [&](common_arg arg) { - if (arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) { + if ((arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) && !arg.is_exclude(ex)) { ctx_arg.options.push_back(std::move(arg)); } }; @@ -649,7 +658,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.prompt = value; } - )); + ).set_excludes({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--no-perf"}, string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"), @@ -673,7 +682,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.prompt.pop_back(); } } - )); + ).set_excludes({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--in-file"}, "FNAME", "an input file (repeat to specify multiple files)", @@ -700,7 +709,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.prompt = ss.str(); fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), value.c_str()); } - )); + ).set_excludes({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-e", "--escape"}, string_format("process escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\) (default: %s)", params.escape ? 
"true" : "false"), @@ -1512,7 +1521,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--lora"}, "FNAME", "path to LoRA adapter (can be repeated to use multiple adapters)", [](common_params & params, const std::string & value) { - params.lora_adapters.push_back({ std::string(value), 1.0 }); + params.lora_adapters.push_back({ std::string(value), 1.0, nullptr }); } // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); @@ -1520,7 +1529,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--lora-scaled"}, "FNAME", "SCALE", "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)", [](common_params & params, const std::string & fname, const std::string & scale) { - params.lora_adapters.push_back({ fname, std::stof(scale) }); + params.lora_adapters.push_back({ fname, std::stof(scale), nullptr }); } // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); diff --git a/common/arg.h b/common/arg.h index a6700d323..49ab8667b 100644 --- a/common/arg.h +++ b/common/arg.h @@ -12,6 +12,7 @@ struct common_arg { std::set examples = {LLAMA_EXAMPLE_COMMON}; + std::set excludes = {}; std::vector args; const char * value_hint = nullptr; // help text or example for arg value const char * value_hint_2 = nullptr; // for second arg value @@ -53,9 +54,11 @@ struct common_arg { ) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {} common_arg & set_examples(std::initializer_list examples); + common_arg & set_excludes(std::initializer_list excludes); common_arg & set_env(const char * env); common_arg & set_sparam(); bool in_example(enum llama_example ex); + bool is_exclude(enum llama_example ex); bool get_value_from_env(std::string & output); bool has_value_from_env(); std::string to_string(); diff --git a/common/common.cpp b/common/common.cpp index 20be92911..86e4e1e24 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2,6 +2,9 @@ #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING #endif +#include "ggml.h" +#include "gguf.h" + #include "common.h" #include "log.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: @@ -18,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -62,7 +66,9 @@ #ifdef __linux__ #include #elif defined(_WIN32) -#define PATH_MAX MAX_PATH +# if !defined(PATH_MAX) +# define PATH_MAX MAX_PATH +# endif #else #include #endif @@ -843,7 +849,7 @@ struct common_init_result common_init_from_params(common_params & params) { } else if (!params.model_url.empty()) { model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); } else { - model = llama_load_model_from_file(params.model.c_str(), mparams); + model = llama_model_load_from_file(params.model.c_str(), mparams); } if (model == NULL) { @@ -870,7 +876,7 @@ struct common_init_result common_init_from_params(common_params & params) { } if (!ok) { - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -881,14 +887,13 @@ struct common_init_result common_init_from_params(common_params & params) { llama_context * lctx = llama_new_context_with_model(model, cparams); if (lctx == NULL) { 
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); + llama_model_free(model); return iparams; } if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) { - LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__); - llama_free_model(model); - return iparams; + LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__); + params.ctx_shift = false; } if (!params.control_vectors.empty()) { @@ -898,7 +903,7 @@ struct common_init_result common_init_from_params(common_params & params) { const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -911,7 +916,7 @@ struct common_init_result common_init_from_params(common_params & params) { params.control_vector_layer_end); if (err) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -919,20 +924,21 @@ struct common_init_result common_init_from_params(common_params & params) { // load and optionally apply lora adapters for (auto & la : params.lora_adapters) { - common_lora_adapter_container loaded_la; - loaded_la.path = la.path; - loaded_la.scale = la.scale; - loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str()); - if (loaded_la.adapter == nullptr) { + llama_lora_adapter_ptr lora; + lora.reset(llama_lora_adapter_init(model, la.path.c_str())); + if (lora == nullptr) { LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } - iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters + + la.ptr = lora.get(); + iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters } + if (!params.lora_init_without_apply) { - common_lora_adapters_apply(lctx, iparams.lora_adapters); + common_lora_adapters_apply(lctx, params.lora_adapters); } if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { @@ -979,7 +985,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_encoder(model)) { llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = bos; } tmp.clear(); @@ -993,17 +999,17 @@ struct common_init_result common_init_from_params(common_params & params) { llama_perf_context_reset(lctx); } - iparams.model = model; - iparams.context = lctx; + iparams.model.reset(model); + iparams.context.reset(lctx); return iparams; } -void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters) { +void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora) { llama_lora_adapter_clear(ctx); - for (auto & la : lora_adapters) { + for (auto & la : lora) { if (la.scale != 0.0f) { - llama_lora_adapter_set(ctx, la.adapter, la.scale); + llama_lora_adapter_set(ctx, la.ptr, la.scale); } } } @@ -1148,8 +1154,7 @@ static bool common_download_file(const std::string & url, const std::string & pa #endif // Check if the file already exists locally - struct stat model_file_info; - auto file_exists = (stat(path.c_str(), &model_file_info) == 0); + auto file_exists = 
std::filesystem::exists(path); // If the file exists, check its JSON metadata companion file. std::string metadata_path = path + ".json"; @@ -1409,7 +1414,7 @@ struct llama_model * common_load_model_from_url( } } - return llama_load_model_from_file(local_path.c_str(), params); + return llama_model_load_from_file(local_path.c_str(), params); } struct llama_model * common_load_model_from_hf( @@ -1612,6 +1617,18 @@ std::string common_detokenize(llama_context * ctx, const std::vector 0) { + std::vector model_template(res + 1, 0); + llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size()); + return std::string(model_template.data(), model_template.size() - 1); + } + return ""; +} + bool common_chat_verify_template(const std::string & tmpl) { llama_chat_message chat[] = {{"user", "test"}}; int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); diff --git a/common/common.h b/common/common.h index 1d2bd932c..0d452cf0f 100644 --- a/common/common.h +++ b/common/common.h @@ -2,7 +2,7 @@ #pragma once -#include "llama.h" +#include "llama-cpp.h" #include #include @@ -27,10 +27,8 @@ struct common_lora_adapter_info { std::string path; float scale; -}; -struct common_lora_adapter_container : common_lora_adapter_info { - struct llama_lora_adapter * adapter; + struct llama_lora_adapter * ptr; }; using llama_tokens = std::vector; @@ -478,10 +476,12 @@ std::string fs_get_cache_file(const std::string & filename); // Model utils // +// note: defines object's lifetime struct common_init_result { - struct llama_model * model = nullptr; - struct llama_context * context = nullptr; - std::vector lora_adapters; + llama_model_ptr model; + llama_context_ptr context; + + std::vector lora; }; struct common_init_result common_init_from_params(common_params & params); @@ -503,7 +503,7 @@ struct llama_model * common_load_model_from_hf( const struct llama_model_params & params); // clear LoRA adapters from context, then apply new list of adapters -void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters); +void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora); // // Batch utils @@ -571,6 +571,9 @@ struct common_chat_msg { std::string content; }; +// Get the built-in chat template for the model. Return empty string if not present. +std::string common_get_builtin_chat_template(const struct llama_model * model); + // Check if the template supplied via "--chat-template" is supported or not. 
Returns true if it's valid bool common_chat_verify_template(const std::string & tmpl); @@ -637,6 +640,10 @@ common_control_vector_data common_control_vector_load(const std::vectorsecond; int max_count_static = 0; int sum_count_static = 0; - llama_token max_token = -1; + llama_token max_token = LLAMA_TOKEN_NULL; for (std::pair token_count_static : part_static) { const llama_token token = token_count_static.first; @@ -85,10 +85,10 @@ static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram } if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) { - return -1; + return LLAMA_TOKEN_NULL; } if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) { - return -1; + return LLAMA_TOKEN_NULL; } return max_token; } @@ -98,9 +98,9 @@ static llama_token try_draft( common_ngram_cache & nc_primary, const std::vector & ngrams_primary, common_ngram_cache_part & part_static, const int * min_sample_size, const int * min_percent) { - llama_token drafted_token = -1; + llama_token drafted_token = LLAMA_TOKEN_NULL; - for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) { + for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == LLAMA_TOKEN_NULL; --i) { const common_ngram ngram_primary = ngrams_primary[i]; common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary); @@ -112,7 +112,7 @@ static llama_token try_draft( int max_count_primary = 0; int max_count_static = 0; int sum_count_primary = 0; - llama_token max_token = -1; + llama_token max_token = LLAMA_TOKEN_NULL; for (std::pair token_count_primary : part_primary) { const llama_token token = token_count_primary.first; @@ -154,7 +154,7 @@ void common_ngram_cache_draft( } while ((int) draft.size()-1 < n_draft) { - llama_token drafted_token = -1; + llama_token drafted_token = LLAMA_TOKEN_NULL; const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1; common_ngram ngram_static; @@ -177,17 +177,17 @@ void common_ngram_cache_draft( } ngrams_cd.push_back(ngram_cd); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { drafted_token = try_draft(nc_static, ngram_static); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { break; } diff --git a/common/ngram-cache.h b/common/ngram-cache.h index 09c2b0319..dfe012abe 100644 --- a/common/ngram-cache.h +++ b/common/ngram-cache.h @@ -17,13 +17,13 @@ struct common_ngram { common_ngram() { for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { - tokens[i] = -1; + tokens[i] = LLAMA_TOKEN_NULL; } } common_ngram(const llama_token * input, const int ngram_size) { for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { - tokens[i] = i < ngram_size ? input[i] : -1; + tokens[i] = i < ngram_size ? 
input[i] : LLAMA_TOKEN_NULL; } } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b6c15da94..5562499aa 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -687,6 +687,9 @@ class Model: if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1": # ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct res = "megrez" + if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3 + res = "deepseek-v3" if res is None: logger.warning("\n") @@ -1764,25 +1767,19 @@ class DeciModel(Model): self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - special_vocab = gguf.SpecialVocab( - self.dir_model, load_merges=True, - special_token_types = ['bos', 'eos', 'eom', 'eot'] - ) - special_vocab._set_special_token("bos", 128000) - special_vocab._set_special_token("eos", 128001) - special_vocab._set_special_token("eom", 128008) - special_vocab._set_special_token("eot", 128009) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) special_vocab.add_to_gguf(self.gguf_writer) else: # DeciLM-7B self._set_vocab_llama_hf() -# self._set_vocab_gpt2() def set_gguf_parameters(self): if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B assert self.block_count == len(self._num_kv_heads) assert self.block_count == len(self._num_heads) assert self.block_count == len(self._ffn_dims) + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) self.gguf_writer.add_head_count_kv(self._num_kv_heads) self.gguf_writer.add_head_count(self._num_heads) self.gguf_writer.add_feed_forward_length(self._ffn_dims) @@ -2565,6 +2562,63 @@ class Phi3MiniModel(Model): yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) +@Model.register("PhiMoEForCausalLM") +class PhiMoeModel(Phi3MiniModel): + model_arch = gguf.MODEL_ARCH.PHIMOE + + _experts: list[dict[str, Tensor]] | None = None + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_count(self.hparams["num_local_experts"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams["num_local_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["w1", "w2", "w3"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in 
self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @Model.register("PlamoForCausalLM") class PlamoModel(Model): model_arch = gguf.MODEL_ARCH.PLAMO @@ -3379,6 +3433,24 @@ class CommandR2Model(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) +@Model.register("Cohere2ForCausalLM") +class Cohere2Model(Model): + model_arch = gguf.MODEL_ARCH.COHERE2 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_logit_scale(self.hparams["logit_scale"]) + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + rotary_pct = self.hparams["rotary_pct"] + hidden_size = self.hparams["hidden_size"] + num_attention_heads = self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads))) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + @Model.register("OlmoForCausalLM") @Model.register("OLMoForCausalLM") class OlmoModel(Model): @@ -3837,6 +3909,7 @@ class DeepseekModel(Model): @Model.register("DeepseekV2ForCausalLM") +@Model.register("DeepseekV3ForCausalLM") class DeepseekV2Model(Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -3858,6 +3931,15 @@ class DeepseekV2Model(Model): self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: @@ -3870,6 +3952,16 @@ class DeepseekV2Model(Model): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index fea23ddb4..56edc64a7 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -107,6 +107,7 @@ models = [ {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, + {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, ] diff --git 
a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index ed1014cae..6dea14a23 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -226,6 +226,9 @@ def get_base_tensor_name(lora_tensor_name: str) -> str: base_name = lora_tensor_name.replace("base_model.model.", "") base_name = base_name.replace(".lora_A.weight", ".weight") base_name = base_name.replace(".lora_B.weight", ".weight") + # models produced by mergekit-extract-lora have token embeddings in the adapter + base_name = base_name.replace(".lora_embedding_A", ".weight") + base_name = base_name.replace(".lora_embedding_B", ".weight") return base_name @@ -260,6 +263,10 @@ def parse_args() -> argparse.Namespace: "--base", type=Path, help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config", ) + parser.add_argument( + "--base-model-id", type=str, + help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')", + ) parser.add_argument( "lora_path", type=Path, help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)", @@ -290,6 +297,7 @@ if __name__ == '__main__': dir_base_model: Path | None = args.base dir_lora: Path = args.lora_path + base_model_id: str | None = args.base_model_id lora_config = dir_lora / "adapter_config.json" input_model = dir_lora / "adapter_model.safetensors" @@ -313,7 +321,10 @@ if __name__ == '__main__': lparams: dict[str, Any] = json.load(f) # load base model - if dir_base_model is None: + if base_model_id is not None: + logger.info(f"Loading base model from Hugging Face: {base_model_id}") + hparams = load_hparams_from_hf(base_model_id) + elif dir_base_model is None: if "base_model_name_or_path" in lparams: model_id = lparams["base_model_name_or_path"] logger.info(f"Loading base model from Hugging Face: {model_id}") @@ -371,11 +382,16 @@ if __name__ == '__main__': if self.lazy: tensor = LazyTorchTensor.from_eager(tensor) base_name = get_base_tensor_name(name) - is_lora_a = ".lora_A.weight" in name - is_lora_b = ".lora_B.weight" in name + # note: mergekit-extract-lora also adds token embeddings to the adapter + is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name + is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name if not is_lora_a and not is_lora_b: if ".base_layer.weight" in name: continue + # mergekit-extract-lora add these layernorm to the adapter, we need to keep them + if "_layernorm" in name or ".norm" in name: + yield (base_name, tensor) + continue logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") if ".embed_tokens.weight" in name or ".lm_head.weight" in name: logger.error("Embeddings is present in the adapter. 
This can be due to new tokens added during fine tuning") @@ -407,9 +423,21 @@ if __name__ == '__main__': if name == "lm_head.weight" and len(dest) == 0: raise ValueError("lm_head is present in adapter, but is ignored in base model") for dest_name, dest_data in dest: + # mergekit-extract-lora add these layernorm to the adapter + if "_norm" in dest_name: + assert dest_data.dim() == 1 + yield (dest_name, dest_data) + continue + + # otherwise, we must get the lora_A and lora_B tensors assert isinstance(dest_data, LoraTorchTensor) lora_a, lora_b = dest_data.get_lora_A_B() + # note: mergekit-extract-lora flip and transpose A and B + # here we only need to transpose token_embd.lora_a, see llm_build_inp_embd() + if "token_embd.weight" in dest_name: + lora_a = lora_a.T + yield (dest_name + ".lora_a", lora_a) yield (dest_name + ".lora_b", lora_b) diff --git a/docs/build.md b/docs/build.md index 84019b204..3b0d2211d 100644 --- a/docs/build.md +++ b/docs/build.md @@ -127,6 +127,8 @@ For detailed info, please refer to [llama.cpp for SYCL](./backend/SYCL.md). This provides GPU acceleration using an NVIDIA GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager (e.g. `apt install nvidia-cuda-toolkit`) or from the [NVIDIA developer site](https://developer.nvidia.com/cuda-downloads). +If you are using Fedora (using Fedora Workstation, or an 'Atomic' variant such as Silverblue), or would like to set up CUDA in a toolbox, please consider our [Fedora CUDA guide](./cuda-fedora.md). Unfortunately, the process is not as simple as one might expect. + - Using `CMake`: ```bash diff --git a/docs/cuda-fedora.md b/docs/cuda-fedora.md new file mode 100644 index 000000000..b993386c8 --- /dev/null +++ b/docs/cuda-fedora.md @@ -0,0 +1,317 @@ +# Setting Up CUDA on Fedora + +In this guide we setup [Nvidia CUDA](https://docs.nvidia.com/cuda/) in a toolbox container. This guide is applicable for: +- [Fedora Workstation](https://fedoraproject.org/workstation/) +- [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/) +- [Fedora Spins](https://fedoraproject.org/spins) +- [Other Distributions](https://containertoolbx.org/distros/), including `Red Hat Enterprise Linux >= 8.`, `Arch Linux`, and `Ubuntu`. 
+ + +## Table of Contents + +- [Prerequisites](#prerequisites) +- [Monitoring NVIDIA CUDA Repositories](#monitoring-nvidia-cuda-repositories) +- [Using the Fedora 39 CUDA Repository](#using-the-fedora-39-cuda-repository) +- [Creating a Fedora Toolbox Environment](#creating-a-fedora-toolbox-environment) +- [Installing Essential Development Tools](#installing-essential-development-tools) +- [Adding the CUDA Repository](#adding-the-cuda-repository) +- [Installing `nvidia-driver-libs`](#installing-nvidia-driver-libs) +- [Manually Resolving Package Conflicts](#manually-resolving-package-conflicts) +- [Finalizing the Installation of `nvidia-driver-libs`](#finalizing-the-installation-of-nvidia-driver-libs) +- [Installing the CUDA Meta-Package](#installing-the-cuda-meta-package) +- [Configuring the Environment](#configuring-the-environment) +- [Verifying the Installation](#verifying-the-installation) +- [Conclusion](#conclusion) +- [Troubleshooting](#troubleshooting) +- [Additional Notes](#additional-notes) +- [References](#references) + +## Prerequisites + +- **Toolbox Installed on the Host System** `Fedora Silverblue` and `Fedora Workstation` both have toolbox by default, other distributions may need to install the [toolbox package](https://containertoolbx.org/install/). +- **NVIDIA Drivers and Graphics Card installed on Host System (optional)** To run CUDA program, such as `llama.cpp`, the host should be setup to access your NVIDIA hardware. Fedora Hosts can use the [RPM Fusion Repository](https://rpmfusion.org/Howto/NVIDIA). +- **Internet connectivity** to download packages. + +### Monitoring NVIDIA CUDA Repositories + +Before proceeding, it is advisable to check if NVIDIA has updated their CUDA repositories for your Fedora version. NVIDIA's repositories can be found at: + +- [Fedora 40 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora40/x86_64/) +- [Fedora 41 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora41/x86_64/) + +As of the latest update, these repositories do not contain the `cuda` meta-package or are missing essential components. + +### Using the Fedora 39 CUDA Repository + +Since the newer repositories are incomplete, we'll use the Fedora 39 repository: + +- [Fedora 39 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora39/x86_64/) + +**Note:** Fedora 39 is no longer maintained, so we recommend using a toolbox environment to prevent system conflicts. + +## Creating a Fedora Toolbox Environment + +This guide focuses on Fedora hosts, but with small adjustments, it can work for other hosts. Using a Fedora 39 toolbox allows us to install the necessary packages without affecting the host system. + +**Note:** Toolbox is available for other systems, and even without Toolbox, it is possible to use Podman or Docker. + +We do not recommend installing on the host system, as Fedora 39 is out-of-maintenance, and instead you should upgrade to a maintained version of Fedora for your host. + +1. **Create a Fedora 39 Toolbox:** + + ```bash + toolbox create --image registry.fedoraproject.org/fedora-toolbox:39 --container fedora-toolbox-39-cuda + ``` + +2. **Enter the Toolbox:** + + ```bash + toolbox enter --container fedora-toolbox-39-cuda + ``` + + Inside the toolbox, you have root privileges and can install packages without affecting the host system. + +## Installing Essential Development Tools + +1. **Synchronize the DNF Package Manager:** + + ```bash + sudo dnf distro-sync + ``` + +2. 
**Install the Default Text Editor (Optional):** + + ```bash + sudo dnf install vim-default-editor --allowerasing + ``` + + The `--allowerasing` flag resolves any package conflicts. + +3. **Install Development Tools and Libraries:** + + ```bash + sudo dnf install @c-development @development-tools cmake + ``` + + This installs essential packages for compiling software, including `gcc`, `make`, and other development headers. + +## Adding the CUDA Repository + +Add the NVIDIA CUDA repository to your DNF configuration: + +```bash +sudo dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/fedora39/x86_64/cuda-fedora39.repo +``` + +After adding the repository, synchronize the package manager again: + +```bash +sudo dnf distro-sync +``` + +## Installing `nvidia-driver-libs` + +Attempt to install `nvidia-driver-libs`: + +```bash +sudo dnf install nvidia-driver-libs +``` + +**Explanation:** + +- `nvidia-driver-libs` contains necessary NVIDIA driver libraries required by CUDA. +- This step might fail due to conflicts with existing NVIDIA drivers on the host system. + +## Manually Resolving Package Conflicts + +If the installation fails due to conflicts, we'll manually download and install the required packages, excluding conflicting files. + +### 1. Download the `nvidia-driver-libs` RPM + +```bash +sudo dnf download --arch x86_64 nvidia-driver-libs +``` + +You should see a file similar to: + +``` +nvidia-driver-libs-560.35.05-1.fc39.x86_64.rpm +``` + +### 2. Attempt to Install the RPM + +```bash +sudo dnf install nvidia-driver-libs-560.35.05-1.fc39.x86_64.rpm +``` + +**Expected Error:** + +Installation may fail with errors pointing to conflicts with `egl-gbm` and `egl-wayland`. + +**Note: It is important to carefully read the error messages to identify the exact paths that need to be excluded.** + +### 3. Download Dependencies + +```bash +sudo dnf download --arch x86_64 egl-gbm egl-wayland +``` + +### 4. Install `egl-gbm` with Excluded Paths + +Exclude conflicting files during installation: + +```bash +sudo rpm --install --verbose --hash \ + --excludepath=/usr/lib64/libnvidia-egl-gbm.so.1.1.2 \ + --excludepath=/usr/share/egl/egl_external_platform.d/15_nvidia_gbm.json \ + egl-gbm-1.1.2^20240919gitb24587d-3.fc39.x86_64.rpm +``` + +**Explanation:** + +- The `--excludepath` option skips installing files that conflict with existing files. +- Adjust the paths based on the error messages you receive. + +### 5. Install `egl-wayland` with Excluded Paths + +```bash +sudo rpm --install --verbose --hash \ + --excludepath=/usr/share/egl/egl_external_platform.d/10_nvidia_wayland.json \ + egl-wayland-1.1.17^20241118giteeb29e1-5.fc39.x86_64.rpm +``` + +### 6. Install `nvidia-driver-libs` with Excluded Paths + +```bash +sudo rpm --install --verbose --hash \ + --excludepath=/usr/share/glvnd/egl_vendor.d/10_nvidia.json \ + --excludepath=/usr/share/nvidia/nvoptix.bin \ + nvidia-driver-libs-560.35.05-1.fc39.x86_64.rpm +``` + +**Note:** + +- Replace the paths with the ones causing conflicts in your installation if they differ. +- The `--verbose` and `--hash` options provide detailed output during installation. + +## Finalizing the Installation of `nvidia-driver-libs` + +After manually installing the dependencies, run: + +```bash +sudo dnf install nvidia-driver-libs +``` + +You should receive a message indicating the package is already installed: + +``` +Package nvidia-driver-libs-3:560.35.05-1.fc39.x86_64 is already installed. +Dependencies resolved. +Nothing to do. +Complete! 
+``` + +## Installing the CUDA Meta-Package + +Now that the driver libraries are installed, proceed to install CUDA: + +```bash +sudo dnf install cuda +``` + +This installs the CUDA toolkit and associated packages. + +## Configuring the Environment + +To use CUDA, add its binary directory to your system's `PATH`. + +1. **Create a Profile Script:** + + ```bash + sudo sh -c 'echo "export PATH=\$PATH:/usr/local/cuda/bin" >> /etc/profile.d/cuda.sh' + ``` + + **Explanation:** + + - We add to `/etc/profile.d/` as the `/etc/` folder is unique to this particular container, and is not shared with other containers or the host system. + - The backslash `\` before `$PATH` ensures the variable is correctly written into the script. + +2. **Make the Script Executable:** + + ```bash + sudo chmod +x /etc/profile.d/cuda.sh + ``` + +3. **Source the Script to Update Your Environment:** + + ```bash + source /etc/profile.d/cuda.sh + ``` + + **Note:** This command updates your current shell session with the new `PATH`. The `/etc/profile.d/cuda.sh` script ensures that the CUDA binaries are available in your `PATH` for all future sessions. + +## Verifying the Installation + +To confirm that CUDA is correctly installed and configured, check the version of the NVIDIA CUDA Compiler (`nvcc`): + +```bash +nvcc --version +``` + +You should see output similar to: + +``` +nvcc: NVIDIA (R) Cuda compiler driver +Copyright (c) 2005-2024 NVIDIA Corporation +Built on Tue_Oct_29_23:50:19_PDT_2024 +Cuda compilation tools, release 12.6, V12.6.85 +Build cuda_12.6.r12.6/compiler.35059454_0 +``` + +This output confirms that the CUDA compiler is accessible and indicates the installed version. + +## Conclusion + +You have successfully set up CUDA on Fedora within a toolbox environment using the Fedora 39 CUDA repository. By manually resolving package conflicts and configuring the environment, you can develop CUDA applications without affecting your host system. + +## Troubleshooting + +- **Installation Failures:** + - If you encounter errors during installation, carefully read the error messages. They often indicate conflicting files or missing dependencies. + - Use the `--excludepath` option with `rpm` to exclude conflicting files during manual installations. + +- **Driver Conflicts:** + - Since the host system may already have NVIDIA drivers installed, conflicts can arise. Using the toolbox environment helps isolate these issues. + +- **Environment Variables Not Set:** + - If `nvcc` is not found after installation, ensure that `/usr/local/cuda/bin` is in your `PATH`. + - Run `echo $PATH` to check if the path is included. + - Re-source the profile script or open a new terminal session. + +## Additional Notes + +- **Updating CUDA in the Future:** + - Keep an eye on the official NVIDIA repositories for updates to your Fedora version. + - When an updated repository becomes available, adjust your `dnf` configuration accordingly. + +- **Building `llama.cpp`:** + - With CUDA installed, you can follow these [build instructions for `llama.cpp`](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) to compile it with CUDA support. + - Ensure that any CUDA-specific build flags or paths are correctly set in your build configuration. + +- **Using the Toolbox Environment:** + - The toolbox environment is isolated from your host system, which helps prevent conflicts. + - Remember that system files and configurations inside the toolbox are separate from the host. 
By default the home directory of the user is shared between the host and the toolbox. + +--- + +**Disclaimer:** Manually installing and modifying system packages can lead to instability of the container. The above steps are provided as a guideline and may need adjustments based on your specific system configuration. Always back up important data before making significant system changes, especially as your home folder is writable and shared with the toolbox. + +**Acknowledgments:** Special thanks to the Fedora community and NVIDIA documentation for providing resources that assisted in creating this guide. + +## References + +- [Fedora Toolbox Documentation](https://docs.fedoraproject.org/en-US/fedora-silverblue/toolbox/) +- [NVIDIA CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) +- [Podman Documentation](https://podman.io/get-started) + +--- diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 04c5ccbbe..8fcd70811 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -28,7 +28,7 @@ The required steps to implement for an HF model are: ```python @Model.register("MyModelForCausalLM") class MyModel(Model): - model_arch = gguf.MODEL_ARCH.GROK + model_arch = gguf.MODEL_ARCH.MYMODEL ``` 2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py) @@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi - `Model#set_vocab` - `Model#write_tensors` -NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights. +NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the convention and several tools like `quantize` expect this to precede the weights. ### 2. Define the model architecture in `llama.cpp` The model params and tensors layout must be defined in `llama.cpp`: 1. Define a new `llm_arch` 2. Define the tensors layout in `LLM_TENSOR_NAMES` -3. Add any non standard metadata in `llm_load_hparams` +3. Add any non-standard metadata in `llm_load_hparams` 4. Create the tensors for inference in `llm_load_tensors` 5. If the model has a RoPE operation, add the rope type in `llama_rope_type` @@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`. -Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`. +Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`. -When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR. +Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR. Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/).
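[Editorial aside, not part of the patch.] The example hunks that follow repeatedly swap `llama_load_model_from_file`/`llama_free_model` for the renamed `llama_model_load_from_file`/`llama_model_free`. A minimal sketch of the updated load/free pairing, assuming the `llama.h` declarations used elsewhere in this diff (`llama_model_default_params`, `llama_context_default_params`, `llama_new_context_with_model`); error handling and inference are trimmed:

```cpp
#include "llama.h"

int main(int argc, char ** argv) {
    llama_backend_init();

    // renamed in this change set: was llama_load_model_from_file
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... run inference ...

    llama_free(ctx);
    llama_model_free(model); // renamed in this change set: was llama_free_model
    llama_backend_free();
    return 0;
}
```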
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index a3b21ad6b..dd75ff9f1 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -38,7 +38,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); @@ -194,7 +194,7 @@ int main(int argc, char ** argv) { llama_batch_free(batch); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index e2e01f2d5..d34b03099 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -41,7 +41,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: error: unable to load model\n" , __func__); @@ -120,7 +120,7 @@ int main(int argc, char ** argv) { } llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = llama_token_bos(model); } @@ -236,7 +236,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 988a584c9..1256abb17 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -1,4 +1,6 @@ #include "ggml.h" +#include "gguf.h" + #include "llama.h" #include "common.h" #include "log.h" @@ -434,12 +436,12 @@ static void print_matrix(struct ggml_tensor * probs) { } } -struct llama_file { +struct my_llama_file { // use FILE * so we don't have to re-open the file to mmap FILE * fp; size_t size; - llama_file(const char * fname, const char * mode) { + my_llama_file(const char * fname, const char * mode) { fp = std::fopen(fname, mode); if (fp == NULL) { size = 0; @@ -500,7 +502,7 @@ struct llama_file { return std::string(chars.data(), len); } - ~llama_file() { + ~my_llama_file() { if (fp) { std::fclose(fp); } @@ -508,7 +510,7 @@ struct llama_file { }; static bool is_ggml_file(const char * filename) { - llama_file file(filename, "rb"); + my_llama_file file(filename, "rb"); if (file.size < 4) { return false; } @@ -576,7 +578,7 @@ static void load_vocab(const char * filename, const Config * config, struct my_l } else { // assume llama2.c vocabulary LOG_INF("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename); - llama_file file(filename, "rb"); + my_llama_file file(filename, "rb"); if (!file.fp) { die_fmt("%s: %s", strerror(errno), filename); } @@ -689,8 +691,8 @@ static void save_as_llama_model( gguf_set_val_u32(ctx, KV_TOKENIZER_UNK_ID, UNKNOWN_TOKEN_ID); gguf_set_val_u32(ctx, KV_TOKENIZER_BOS_ID, BOS_TOKEN_ID); gguf_set_val_u32(ctx, KV_TOKENIZER_EOS_ID, EOS_TOKEN_ID); - 
gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, -1); - gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, -1); + gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, LLAMA_TOKEN_NULL); + gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, LLAMA_TOKEN_NULL); gguf_set_val_u32(ctx, KV_CONTEXT_LENGTH, model->hparams.n_ctx); gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd); diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index d1731bba6..e899c1078 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -1,7 +1,9 @@ +#include "ggml.h" +#include "gguf.h" + #include "arg.h" #include "common.h" #include "llama.h" -#include "ggml.h" #include "pca.hpp" #include "mean.hpp" @@ -415,12 +417,13 @@ int main(int argc, char ** argv) { // load the model to get hparams common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); // int n_ctx = llama_n_ctx(ctx); int n_layers = llama_n_layer(model); int n_embd = llama_n_embd(model); + // get model hint param (a.k.a model arch name) char model_hint[128]; llama_model_meta_val_str(model, "general.architecture", model_hint, 128); @@ -474,8 +477,6 @@ int main(int argc, char ** argv) { // done with the model, we can now free it to make gain some memory printf("Done evaluate prompts, unload model...\n"); - llama_free(ctx); - llama_free_model(model); bool use_pca = params.cvector_dimre_method == DIMRE_METHOD_PCA; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 3f18fc6a7..27f75cb77 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -97,8 +97,9 @@ int main(int argc, char ** argv) { // load the model common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); + if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); return 1; @@ -316,8 +317,6 @@ int main(int argc, char ** argv) { // clean up llama_batch_free(batch); - llama_free(ctx); - llama_free_model(model); llama_backend_free(); return 0; diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index c08e3e5f6..2111c3cda 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -162,8 +162,9 @@ int main(int argc, char ** argv) { // init common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); + if (model == nullptr || ctx == nullptr) { LOG_ERR("%s : failed to init\n", __func__); return 1; @@ -184,9 +185,6 @@ int main(int argc, char ** argv) { LOG("\n"); llama_perf_context_print(ctx); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); return 0; diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index 058b5cc86..d5dcd20a0 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -1,7 +1,9 @@ -#include "arg.h" -#include "common.h" #include "ggml.h" #include "ggml-alloc.h" +#include 
"gguf.h" + +#include "arg.h" +#include "common.h" #include #include diff --git a/examples/gguf-hash/gguf-hash.cpp b/examples/gguf-hash/gguf-hash.cpp index e96c75117..9523ec122 100644 --- a/examples/gguf-hash/gguf-hash.cpp +++ b/examples/gguf-hash/gguf-hash.cpp @@ -1,4 +1,5 @@ #include "ggml.h" +#include "gguf.h" #include /* abort() */ #include diff --git a/examples/gguf-split/gguf-split.cpp b/examples/gguf-split/gguf-split.cpp index 75f63f938..ef3ceb686 100644 --- a/examples/gguf-split/gguf-split.cpp +++ b/examples/gguf-split/gguf-split.cpp @@ -1,18 +1,19 @@ +#include "ggml.h" +#include "gguf.h" #include "llama.h" #include "common.h" #include -#include +#include +#include +#include #include +#include +#include #include #include #include -#include -#include -#include -#include - #if defined(_WIN32) #include #ifndef PATH_MAX @@ -297,7 +298,7 @@ struct split_strategy { total_size += ggml_nbytes(t); } total_size = total_size / 1000 / 1000; // convert to megabytes - printf("split %05d: n_tensors = %d, total_size = %zuM\n", i_split + 1, gguf_get_n_tensors(ctx_out), total_size); + printf("split %05d: n_tensors = %" PRIi64 ", total_size = %zuM\n", i_split + 1, gguf_get_n_tensors(ctx_out), total_size); i_split++; } } diff --git a/examples/gguf/gguf.cpp b/examples/gguf/gguf.cpp index 7498f85ef..f31989c8c 100644 --- a/examples/gguf/gguf.cpp +++ b/examples/gguf/gguf.cpp @@ -1,10 +1,9 @@ #include "ggml.h" +#include "gguf.h" #include -#include #include #include -#include #include #undef MIN @@ -135,9 +134,10 @@ static bool gguf_ex_read_0(const std::string & fname) { for (int i = 0; i < n_tensors; ++i) { const char * name = gguf_get_tensor_name (ctx, i); + const size_t size = gguf_get_tensor_size (ctx, i); const size_t offset = gguf_get_tensor_offset(ctx, i); - printf("%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset); + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset); } } @@ -182,9 +182,10 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { for (int i = 0; i < n_tensors; ++i) { const char * name = gguf_get_tensor_name (ctx, i); + const size_t size = gguf_get_tensor_size (ctx, i); const size_t offset = gguf_get_tensor_offset(ctx, i); - printf("%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset); + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset); } } @@ -199,7 +200,8 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name); - printf("%s: tensor[%d]: n_dims = %d, name = %s, data = %p\n", __func__, i, ggml_n_dims(cur), cur->name, cur->data); + printf("%s: tensor[%d]: n_dims = %d, ne = (%d, %d, %d, %d), name = %s, data = %p\n", + __func__, i, ggml_n_dims(cur), int(cur->ne[0]), int(cur->ne[1]), int(cur->ne[2]), int(cur->ne[3]), cur->name, cur->data); // print first 10 elements const float * data = (const float *) cur->data; @@ -215,7 +217,7 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { const float * data = (const float *) cur->data; for (int j = 0; j < ggml_nelements(cur); ++j) { if (data[j] != 100 + i) { - fprintf(stderr, "%s: tensor[%d]: data[%d] = %f\n", __func__, i, j, data[j]); + fprintf(stderr, "%s: tensor[%d], data[%d]: found %f, expected %f\n", __func__, i, j, data[j], float(100 + i)); gguf_free(ctx); return false; } @@ -245,6 +247,8 @@ int main(int argc, char ** argv) { check_data = false; } + srand(123456); + const std::string 
fname(argv[1]); const std::string mode (argv[2]); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 18a945b33..4d2db5624 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -165,7 +165,7 @@ int main(int argc, char * argv[]) { llama_backend_init(); - llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); + llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams); // create generation context llama_context * ctx = llama_new_context_with_model(model, cparams); @@ -219,7 +219,7 @@ int main(int argc, char * argv[]) { llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); return 0; diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 45206f4a7..588114ecd 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -430,9 +430,10 @@ static void process_logits( static bool compute_imatrix(llama_context * ctx, const common_params & params) { const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); const int n_ctx = llama_n_ctx(ctx); + GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); + auto tim1 = std::chrono::high_resolution_clock::now(); LOG_INF("%s: tokenizing the input ..\n", __func__); @@ -618,8 +619,9 @@ int main(int argc, char ** argv) { // init common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); + if (model == nullptr || ctx == nullptr) { LOG_ERR("%s : failed to init\n", __func__); return 1; @@ -655,9 +657,6 @@ int main(int argc, char ** argv) { LOG("\n"); llama_perf_context_print(ctx); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); return 0; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index ef7008957..d460be314 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -131,8 +131,8 @@ int main(int argc, char ** argv) { LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__); common_init_result llama_init = common_init_from_params(params); - model = llama_init.model; - ctx = llama_init.context; + model = llama_init.model.get(); + ctx = llama_init.context.get(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); @@ -581,9 +581,6 @@ int main(int argc, char ** argv) { LOG("\n"); common_perf_print(ctx, smpl); - llama_free(ctx); - llama_free_model(model); - common_sampler_free(smpl); llama_backend_free(); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 2338ad106..2a0916766 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1526,10 +1526,10 @@ int main(int argc, char ** argv) { // keep the same model between tests when possible if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) { if (lmodel) { - llama_free_model(lmodel); + llama_model_free(lmodel); } - lmodel = llama_load_model_from_file(inst.model.c_str(), inst.to_llama_mparams()); + lmodel = llama_model_load_from_file(inst.model.c_str(), inst.to_llama_mparams()); if (lmodel == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str()); return 1; @@ -1540,7 +1540,7 @@ int main(int argc, char ** argv) { llama_context * ctx = 
llama_new_context_with_model(lmodel, inst.to_llama_cparams()); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str()); - llama_free_model(lmodel); + llama_model_free(lmodel); return 1; } @@ -1626,7 +1626,7 @@ int main(int argc, char ** argv) { ggml_threadpool_free_fn(threadpool); } - llama_free_model(lmodel); + llama_model_free(lmodel); if (p) { p->print_footer(); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index b3858ddfb..66ec2aeeb 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -305,7 +305,9 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { - llama_batch_free(*reinterpret_cast(batch_pointer)); + //llama_batch_free(*reinterpret_cast(batch_pointer)); + const auto batch = reinterpret_cast(batch_pointer); + delete batch; } extern "C" diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 3cd0d2fa8..7a8a3156b 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -7,6 +7,7 @@ #include "ggml-cpu.h" #include "ggml-alloc.h" #include "ggml-backend.h" +#include "gguf.h" //#ifdef GGML_USE_CUDA //#include "ggml-cuda.h" @@ -262,7 +263,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { { const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); int arr_n = gguf_get_arr_n(ctx_gguf, i); - const void * data = gguf_get_arr_data(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? 
nullptr : gguf_get_arr_data(ctx_gguf, i); std::stringstream ss; ss << "["; for (int j = 0; j < arr_n; j++) { @@ -2734,7 +2735,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i total_size_org += orig_size; total_size_new += new_size; gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); + GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size); + gguf_set_tensor_data(ctx_out, name.c_str(), new_data); fout.write((const char *)new_data, new_size); size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size; for (size_t j = 0; j < pad; ++j) { diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 2691c6e6b..27215a42e 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -221,7 +221,7 @@ static struct llama_model * llava_init(common_params * params) { llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); return NULL; @@ -265,7 +265,7 @@ static void llava_free(struct llava_context * ctx_llava) { } llama_free(ctx_llava->ctx_llama); - llama_free_model(ctx_llava->model); + llama_model_free(ctx_llava->model); llama_backend_free(); } @@ -323,7 +323,7 @@ int main(int argc, char ** argv) { } } - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index e9cbb51ed..2342bdd09 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -31,7 +31,7 @@ static struct llama_model * llava_init(common_params * params) { llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); return NULL; @@ -75,7 +75,7 @@ static void llava_free(struct llava_context * ctx_llava) { } llama_free(ctx_llava->ctx_llama); - llama_free_model(ctx_llava->model); + llama_model_free(ctx_llava->model); llama_backend_free(); } diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index e86a60280..f3e5d66e2 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -310,7 +310,7 @@ static struct llama_model * llava_init(common_params * params) { llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); return NULL; @@ -354,7 +354,7 @@ static void llava_free(struct llava_context * ctx_llava) { } llama_free(ctx_llava->ctx_llama); - llama_free_model(ctx_llava->model); + llama_model_free(ctx_llava->model); llama_backend_free(); } @@ -575,7 +575,7 @@ int main(int argc, char ** argv) { } } - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 8d0ef8b3d..e016618e3 100644 --- 
a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -58,8 +58,8 @@ int main(int argc, char ** argv) { // load the target model common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); // Tokenize the prompt std::vector inp; @@ -474,9 +474,6 @@ int main(int argc, char ** argv) { llama_batch_free(batch); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index 7ced0aa97..3da45ed9e 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -1,14 +1,9 @@ #include "arg.h" #include "common.h" #include "ngram-cache.h" -#include "ggml.h" #include "llama.h" -#include -#include -#include #include -#include #include int main(int argc, char ** argv){ @@ -25,16 +20,16 @@ int main(int argc, char ** argv){ // load the model common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model_ptr & model = llama_init.model; + llama_context_ptr & ctx = llama_init.context; + GGML_ASSERT(model != nullptr); // tokenize the prompt std::vector inp; - inp = common_tokenize(ctx, params.prompt, true, true); + inp = common_tokenize(ctx.get(), params.prompt, true, true); fprintf(stderr, "%s: tokenization done\n", __func__); - common_ngram_cache ngram_cache; common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true); fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str()); diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index dff07c075..fcb289abe 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -30,12 +30,11 @@ int main(int argc, char ** argv){ // load the model common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_context_ptr & ctx = llama_init.context; // tokenize the prompt std::vector inp; - inp = common_tokenize(ctx, params.prompt, true, true); + inp = common_tokenize(ctx.get(), params.prompt, true, true); common_ngram_cache ngram_cache_context; common_ngram_cache ngram_cache_dynamic; @@ -66,7 +65,7 @@ int main(int argc, char ** argv){ } const int n_input = inp.size(); - const int n_ctx = llama_n_ctx(ctx); + const int n_ctx = llama_n_ctx(ctx.get()); int n_drafted = 0; int n_accept = 0; @@ -150,9 +149,6 @@ int main(int argc, char ** argv){ LOG_INF("n_accept = %d\n", n_accept); LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 4d92bb238..0d68b80b9 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -33,8 +33,8 @@ int main(int argc, char ** argv){ // load the model common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); // tokenize the prompt std::vector inp; @@ -243,9 +243,6 @@ int main(int argc, char ** argv){ 
llama_batch_free(batch_tgt); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d0c28f317..aaee47e32 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -145,18 +145,18 @@ int main(int argc, char ** argv) { llama_context * ctx = nullptr; common_sampler * smpl = nullptr; - std::vector chat_msgs; - g_model = &model; g_ctx = &ctx; g_smpl = &smpl; + std::vector chat_msgs; + // load the model and apply lora adapter, if any LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__); common_init_result llama_init = common_init_from_params(params); - model = llama_init.model; - ctx = llama_init.context; + model = llama_init.model.get(); + ctx = llama_init.context.get(); if (model == NULL) { LOG_ERR("%s: error: unable to load model\n", __func__); @@ -494,7 +494,7 @@ int main(int argc, char ** argv) { } llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = llama_token_bos(model); } @@ -831,7 +831,7 @@ int main(int argc, char ** argv) { // if user stop generation mid-way, we must add EOT to finish model's last response if (need_insert_eot && format_chat) { llama_token eot = llama_token_eot(model); - embd_inp.push_back(eot == -1 ? llama_token_eos(model) : eot); + embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_token_eos(model) : eot); need_insert_eot = false; } @@ -889,9 +889,6 @@ int main(int argc, char ** argv) { common_sampler_free(smpl); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); ggml_threadpool_free_fn(threadpool); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index fd2b1c011..d48f51975 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -132,8 +132,8 @@ int main(int argc, char ** argv) { // load the target model common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); // load the prompts from an external file if there are any if (params.prompt.empty()) { @@ -416,9 +416,6 @@ int main(int argc, char ** argv) { llama_batch_free(batch); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 09bba708f..ea91f376c 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -63,7 +63,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); @@ -266,7 +266,7 @@ int main(int argc, char ** argv) { llama_batch_free(batch); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 64a84607c..6bdc57f8e 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1987,8 +1987,9 @@ int main(int argc, char ** argv) { // load the model and apply lora adapter, if any common_init_result llama_init = 
common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); + if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); return 1; @@ -2023,9 +2024,6 @@ int main(int argc, char ** argv) { LOG("\n"); llama_perf_context_print(ctx); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); return 0; diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 912caf346..9bfbb8862 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -1,7 +1,7 @@ -#include "common.h" #include "ggml.h" #include "llama.h" -#include "llama-impl.h" +#include "llama-context.h" +#include "common.h" #include #include @@ -9,11 +9,9 @@ #include #include #include -#include #include #include #include -#include #include #include #include @@ -311,7 +309,7 @@ int main(int argc, char ** argv) { auto mparams = llama_model_default_params(); mparams.use_mlock = false; - model = llama_load_model_from_file(params.model.c_str(), mparams); + model = llama_model_load_from_file(params.model.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); @@ -325,18 +323,18 @@ int main(int argc, char ** argv) { if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); + llama_model_free(model); return 1; } } - const auto &tensors = llama_internal_get_tensor_map(ctx); + const auto & tensors = llama_internal_get_tensor_map(ctx); // check layer tensors int included_layers = 0; int64_t max_nelements = 0; bool is_f16 = false; - for (const auto& kv_tensor : tensors) { + for (const auto & kv_tensor : tensors) { if (!layer_included(params, kv_tensor.first)) { continue; } @@ -349,7 +347,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: error: Quantization should be tested with a float model, " "this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 1; } included_layers++; @@ -371,8 +369,8 @@ int main(int argc, char ** argv) { if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), i) == params.include_types.end()) { continue; } - const auto * qfns = ggml_get_type_traits(type); - const auto * qfns_cpu = ggml_get_type_traits_cpu(type); + const auto * qfns = ggml_get_type_traits(type); + const auto * qfns_cpu = ggml_get_type_traits_cpu(type); if (qfns_cpu->from_float && qfns->to_float) { if (params.verbose) { printf("testing %s ...\n", ggml_type_name(type)); @@ -382,7 +380,7 @@ int main(int argc, char ** argv) { error_stats global_stats {}; - for (const auto& kv_tensor : tensors) { + for (const auto & kv_tensor : tensors) { if (!layer_included(params, kv_tensor.first)) { continue; } @@ -411,7 +409,7 @@ int main(int argc, char ** argv) { llama_free(ctx); - llama_free_model(model); + llama_model_free(model); // report timing { const int64_t t_main_end_us = ggml_time_us(); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index a5c6fe7e5..f534b5eff 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -151,8 +151,8 @@ int main(int argc, char ** argv) 
{ // load the model common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); @@ -298,7 +298,5 @@ int main(int argc, char ** argv) { // clean up llama_batch_free(query_batch); - llama_free(ctx); - llama_free_model(model); llama_backend_free(); } diff --git a/examples/run/run.cpp b/examples/run/run.cpp index f89d041c4..61420e441 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -1,5 +1,6 @@ #if defined(_WIN32) # include +# include #else # include # include @@ -10,6 +11,8 @@ # include #endif +#include + #include #include #include @@ -24,6 +27,13 @@ #include "json.hpp" #include "llama-cpp.h" +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) +[[noreturn]] static void sigint_handler(int) { + printf("\n"); + exit(0); // not ideal, but it's the only way to guarantee exit in all cases +} +#endif + GGML_ATTRIBUTE_FORMAT(1, 2) static std::string fmt(const char * fmt, ...) { va_list ap; @@ -82,6 +92,7 @@ class Opt { } ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default; + ctx_params.n_ctx = ctx_params.n_batch; model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default; temperature = temperature >= 0 ? temperature : temperature_default; @@ -253,7 +264,7 @@ class File { return 1; } - OVERLAPPED overlapped = { 0 }; + OVERLAPPED overlapped = {}; if (!LockFileEx(hFile, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD, &overlapped)) { fd = -1; @@ -277,7 +288,7 @@ class File { if (fd >= 0) { # ifdef _WIN32 if (hFile != INVALID_HANDLE_VALUE) { - OVERLAPPED overlapped = { 0 }; + OVERLAPPED overlapped = {}; UnlockFileEx(hFile, 0, MAXDWORD, MAXDWORD, &overlapped); } # else @@ -293,7 +304,7 @@ class File { private: int fd = -1; # ifdef _WIN32 - HANDLE hFile; + HANDLE hFile = nullptr; # endif }; @@ -464,7 +475,7 @@ class HttpClient { return (now_downloaded_plus_file_size * 100) / total_to_download; } - static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", percentage); } + static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast(percentage)); } static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) { const auto now = std::chrono::steady_clock::now(); @@ -663,7 +674,7 @@ class LlamaData { "\r%*s" "\rLoading model", get_terminal_width(), " "); - llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), opt.model_params)); + llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params)); if (!model) { printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str()); } @@ -799,7 +810,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str static int read_user_input(std::string & user) { std::getline(std::cin, user); - return user.empty(); // Should have data in happy path + if (std::cin.eof()) { + printf("\n"); + return 1; + } + + if (user == "/bye") { + return 1; + } + + if (user.empty()) { + return 2; + } + + return 0; // Should have data in happy path } // Function to generate a response based on the prompt @@ -866,7 +890,25 @@ static bool is_stdout_a_terminal() { #endif } -// Function to tokenize the prompt +// Function to handle 
user input +static int get_user_input(std::string & user_input, const std::string & user) { + while (true) { + const int ret = handle_user_input(user_input, user); + if (ret == 1) { + return 1; + } + + if (ret == 2) { + continue; + } + + break; + } + + return 0; +} + +// Main chat loop function static int chat_loop(LlamaData & llama_data, const std::string & user) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); @@ -874,7 +916,8 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) { while (true) { // Get user input std::string user_input; - while (handle_user_input(user_input, user)) { + if (get_user_input(user_input, user) == 1) { + return 0; } add_message("user", user.empty() ? user_input : user, llama_data); @@ -915,7 +958,23 @@ static std::string read_pipe_data() { return result.str(); } +static void ctrl_c_handling() { +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = sigint_handler; + sigemptyset(&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); +#elif defined(_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif +} + int main(int argc, const char ** argv) { + ctrl_c_handling(); Opt opt; const int ret = opt.init(argc, argv); if (ret == 2) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 2f0cf9baa..cd03661cf 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -30,8 +30,8 @@ int main(int argc, char ** argv) { // init common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); if (model == nullptr || ctx == nullptr) { fprintf(stderr, "%s : failed to init\n", __func__); @@ -89,8 +89,6 @@ int main(int argc, char ** argv) { if (llama_decode(ctx, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_batch_free(batch); - llama_free(ctx); - llama_free_model(model); return 1; } n_past += 1; @@ -98,11 +96,8 @@ int main(int argc, char ** argv) { printf("\n\n"); - // free old context - llama_free(ctx); - // make new context - auto * ctx2 = llama_new_context_with_model(model, common_context_params_to_llama(params)); + llama_context * ctx2 = llama_new_context_with_model(model, common_context_params_to_llama(params)); llama_sampler * smpl2 = llama_sampler_chain_init(sparams); @@ -123,8 +118,6 @@ int main(int argc, char ** argv) { if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) { fprintf(stderr, "\n%s : failed to read state\n", __func__); - llama_free(ctx2); - llama_free_model(model); return 1; } @@ -148,8 +141,6 @@ int main(int argc, char ** argv) { if (llama_decode(ctx2, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_batch_free(batch); - llama_free(ctx2); - llama_free_model(model); return 1; } n_past += 1; @@ -157,15 +148,13 @@ int main(int argc, char ** argv) { printf("\n\n"); - llama_free(ctx2); - if (result0 != result1) { fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__); return 1; } // make new context - auto * ctx3 = 
llama_new_context_with_model(model, common_context_params_to_llama(params)); + llama_context * ctx3 = llama_new_context_with_model(model, common_context_params_to_llama(params)); llama_sampler * smpl3 = llama_sampler_chain_init(sparams); @@ -186,8 +175,6 @@ int main(int argc, char ** argv) { if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) { fprintf(stderr, "\n%s : failed to read state\n", __func__); - llama_free(ctx3); - llama_free_model(model); return 1; } @@ -204,8 +191,6 @@ int main(int argc, char ** argv) { const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0); if (ncopy != seq_store.size()) { fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); - llama_free(ctx3); - llama_free_model(model); return 1; } fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); @@ -218,8 +203,6 @@ int main(int argc, char ** argv) { const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1); if (nset != seq_store.size()) { fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); - llama_free(ctx3); - llama_free_model(model); return 1; } fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset); @@ -239,8 +222,6 @@ int main(int argc, char ** argv) { if (llama_decode(ctx3, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_batch_free(batch); - llama_free(ctx3); - llama_free_model(model); return 1; } n_past += 1; @@ -253,8 +234,6 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl3); llama_batch_free(batch); - llama_free(ctx3); - llama_free_model(model); if (result0 != result2) { fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); diff --git a/examples/server/README.md b/examples/server/README.md index 07436057a..1f0a27d96 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -45,10 +45,7 @@ The project is under active development, and we are [looking for feedback and co | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | | `-fa, --flash-attn` | enable Flash Attention (default: disabled)
(env: LLAMA_ARG_FLASH_ATTN) | -| `-p, --prompt PROMPT` | prompt to start generation with | | `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | -| `-f, --file FNAME` | a file containing the prompt (default: none) | -| `-bf, --binary-file FNAME` | binary file containing the prompt (default: none) | | `-e, --escape` | process escape sequences (\n, \r, \t, \', \", \\) (default: true) | | `--no-escape` | do not process escape sequences | | `--rope-scaling {none,linear,yarn}` | RoPE frequency scaling method, defaults to linear unless specified by the model
(env: LLAMA_ARG_ROPE_SCALING_TYPE) | @@ -345,7 +342,7 @@ node index.js > [!IMPORTANT] > -> This endpoint is **not** OAI-compatible +> This endpoint is **not** OAI-compatible. For OAI-compatible client, use `/v1/completions` instead. *Options:* @@ -452,6 +449,8 @@ These words will not be included in the completion, so make sure to add them to `response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name. +`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation. + **Response format** - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support. @@ -523,6 +522,7 @@ These words will not be included in the completion, so make sure to add them to - `tokens_evaluated`: Number of tokens evaluated in total from the prompt - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`) + ### POST `/tokenize`: Tokenize a given text *Options:* @@ -574,6 +574,10 @@ With input 'á' (utf8 hex: C3 A1) on tinyllama/stories260k ### POST `/embedding`: Generate embedding of a given text +> [!IMPORTANT] +> +> This endpoint is **not** OAI-compatible. For OAI-compatible client, use `/v1/embeddings` instead. + The same as [the embedding example](../embedding) does. *Options:* @@ -744,96 +748,6 @@ To use this endpoint with POST method, you need to start server with `--props` - None yet -### POST `/v1/chat/completions`: OpenAI-compatible Chat Completions API - -Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. - -*Options:* - -See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). While some OpenAI-specific features such as function calling aren't supported, llama.cpp `/completion`-specific features such as `mirostat` are supported. - -The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. 
`{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}` or `{"type": "json_schema", "schema": {"properties": { "name": { "title": "Name", "type": "string" }, "date": { "title": "Date", "type": "string" }, "participants": { "items": {"type: "string" }, "title": "Participants", "type": "string" } } } }`), similar to other OpenAI-inspired API providers. - -*Examples:* - -You can use either Python `openai` library with appropriate checkpoints: - -```python -import openai - -client = openai.OpenAI( - base_url="http://localhost:8080/v1", # "http://:port" - api_key = "sk-no-key-required" -) - -completion = client.chat.completions.create( -model="gpt-3.5-turbo", -messages=[ - {"role": "system", "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests."}, - {"role": "user", "content": "Write a limerick about python exceptions"} -] -) - -print(completion.choices[0].message) -``` - -... or raw HTTP requests: - -```shell -curl http://localhost:8080/v1/chat/completions \ --H "Content-Type: application/json" \ --H "Authorization: Bearer no-key" \ --d '{ -"model": "gpt-3.5-turbo", -"messages": [ -{ - "role": "system", - "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests." -}, -{ - "role": "user", - "content": "Write a limerick about python exceptions" -} -] -}' -``` - -### POST `/v1/embeddings`: OpenAI-compatible embeddings API - -This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm. - -*Options:* - -See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings). - -*Examples:* - -- input as string - - ```shell - curl http://localhost:8080/v1/embeddings \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer no-key" \ - -d '{ - "input": "hello", - "model":"GPT-4", - "encoding_format": "float" - }' - ``` - -- `input` as string array - - ```shell - curl http://localhost:8080/v1/embeddings \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer no-key" \ - -d '{ - "input": ["hello", "world"], - "model":"GPT-4", - "encoding_format": "float" - }' - ``` - ### POST `/embeddings`: non-OpenAI-compatible embeddings API This endpoint supports all poolings, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using Euclidian norm. @@ -1030,6 +944,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply` +Please note that this value will be overwritten by the `lora` field for each request. + If an adapter is disabled, the scale will be set to 0. **Response format** @@ -1051,6 +967,8 @@ If an adapter is disabled, the scale will be set to 0. ### POST `/lora-adapters`: Set list of LoRA adapters +This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request. + To disable an adapter, either remove it from the list below, or set scale to 0. 
**Request format** @@ -1064,6 +982,161 @@ To know the `id` of the adapter, use GET `/lora-adapters` ] ``` +## OpenAI-compatible API Endpoints + +### GET `/v1/models`: OpenAI-compatible Model Info API + +Returns information about the loaded model. See [OpenAI Models API documentation](https://platform.openai.com/docs/api-reference/models). + +The returned list always has one single element. + +By default, model `id` field is the path to model file, specified via `-m`. You can set a custom value for model `id` field via `--alias` argument. For example, `--alias gpt-4o-mini`. + +Example: + +```json +{ + "object": "list", + "data": [ + { + "id": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", + "object": "model", + "created": 1735142223, + "owned_by": "llamacpp", + "meta": { + "vocab_type": 2, + "n_vocab": 128256, + "n_ctx_train": 131072, + "n_embd": 4096, + "n_params": 8030261312, + "size": 4912898304 + } + } + ] +} +``` + +### POST `/v1/completions`: OpenAI-compatible Completions API + +Given an input `prompt`, it returns the predicted completion. Streaming mode is also supported. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. + +*Options:* + +See [OpenAI Completions API documentation](https://platform.openai.com/docs/api-reference/completions). + +llama.cpp `/completion`-specific features such as `mirostat` are supported. + +*Examples:* + +Example usage with `openai` python library: + +```python +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/v1", # "http://:port" + api_key = "sk-no-key-required" +) + +completion = client.completions.create( + model="davinci-002", + prompt="I believe the meaning of life is", + max_tokens=8 +) + +print(completion.choices[0].text) +``` + +### POST `/v1/chat/completions`: OpenAI-compatible Chat Completions API + +Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. + +*Options:* + +See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). While some OpenAI-specific features such as function calling aren't supported, llama.cpp `/completion`-specific features such as `mirostat` are supported. + +The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. `{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}` or `{"type": "json_schema", "schema": {"properties": { "name": { "title": "Name", "type": "string" }, "date": { "title": "Date", "type": "string" }, "participants": { "items": {"type: "string" }, "title": "Participants", "type": "string" } } } }`), similar to other OpenAI-inspired API providers. 
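As a hypothetical illustration of the schema-constrained form, a request might pin the reply to a small JSON object as shown below; the schema itself is only an example, and the general request/response examples follow further down.

```shell
# constrain the assistant reply to a JSON object with a single "color" string field
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
  {"role": "user", "content": "Name one color and nothing else."}
],
"response_format": {
  "type": "json_object",
  "schema": {"type": "object", "properties": {"color": {"type": "string"}}, "required": ["color"]}
}
}'
```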
+ +*Examples:* + +You can use either Python `openai` library with appropriate checkpoints: + +```python +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/v1", # "http://:port" + api_key = "sk-no-key-required" +) + +completion = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests."}, + {"role": "user", "content": "Write a limerick about python exceptions"} + ] +) + +print(completion.choices[0].message) +``` + +... or raw HTTP requests: + +```shell +curl http://localhost:8080/v1/chat/completions \ +-H "Content-Type: application/json" \ +-H "Authorization: Bearer no-key" \ +-d '{ +"model": "gpt-3.5-turbo", +"messages": [ +{ + "role": "system", + "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests." +}, +{ + "role": "user", + "content": "Write a limerick about python exceptions" +} +] +}' +``` + +### POST `/v1/embeddings`: OpenAI-compatible embeddings API + +This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm. + +*Options:* + +See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings). + +*Examples:* + +- input as string + + ```shell + curl http://localhost:8080/v1/embeddings \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer no-key" \ + -d '{ + "input": "hello", + "model":"GPT-4", + "encoding_format": "float" + }' + ``` + +- `input` as string array + + ```shell + curl http://localhost:8080/v1/embeddings \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer no-key" \ + -d '{ + "input": ["hello", "world"], + "model":"GPT-4", + "encoding_format": "float" + }' + ``` + ## More examples ### Interactive mode diff --git a/examples/server/bench/README.md b/examples/server/bench/README.md index 353368e13..9549795ec 100644 --- a/examples/server/bench/README.md +++ b/examples/server/bench/README.md @@ -6,10 +6,10 @@ Benchmark is using [k6](https://k6.io/). SSE is not supported by default in k6, you have to build k6 with the [xk6-sse](https://github.com/phymbert/xk6-sse) extension. 
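Once built (see the example just below), you can sanity-check that the custom binary actually carries the extension; the exact wording of the version output varies between k6 releases, so treat this only as a rough check.

```shell
# an xk6-built binary lists its bundled extensions in the version output
./k6 version
```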
-Example: +Example (assuming golang >= 1.21 is installed): ```shell go install go.k6.io/xk6/cmd/xk6@latest -xk6 build master \ +$GOPATH/bin/xk6 build master \ --with github.com/phymbert/xk6-sse ``` @@ -33,7 +33,7 @@ The server must answer OAI Chat completion requests on `http://localhost:8080/v1 Example: ```shell -server --host localhost --port 8080 \ +llama-server --host localhost --port 8080 \ --model ggml-model-q4_0.gguf \ --cont-batching \ --metrics \ diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index a9ed747f5..5cc6f92ab 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -189,12 +189,12 @@ xychart-beta "pp": { "p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2), "avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2), - "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2), + "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0, }, "tg": { "p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2), "avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2), - "0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2), + "0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0, }, } with open("results.github.env", 'a') as github_env: @@ -214,11 +214,14 @@ def start_benchmark(args): k6_args = [ 'run', args.scenario, '--no-color', + '--no-connection-reuse', + '--no-vu-connection-reuse', ] k6_args.extend(['--duration', args.duration]) k6_args.extend(['--iterations', args.n_prompts]) k6_args.extend(['--vus', args.parallel]) k6_args.extend(['--summary-export', 'k6-results.json']) + k6_args.extend(['--out', 'csv=k6-results.csv']) args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} " args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]]) print(f"bench: starting k6 with: {args}") @@ -231,7 +234,7 @@ def start_server(args): server_process = start_server_background(args) attempts = 0 - max_attempts = 20 + max_attempts = 600 if 'GITHUB_ACTIONS' in os.environ: max_attempts *= 2 @@ -242,7 +245,15 @@ def start_server(args): print(f"bench: waiting for server to start ...") time.sleep(0.5) - print("bench: server started.") + attempts = 0 + while not is_server_ready(args.host, args.port): + attempts += 1 + if attempts > max_attempts: + assert False, "server not ready" + print(f"bench: waiting for server to be ready ...") + time.sleep(0.5) + + print("bench: server started and ready.") return server_process @@ -255,11 +266,6 @@ def start_server_background(args): '--host', args.host, '--port', args.port, ] - model_file = args.model_path_prefix + os.path.sep + args.hf_file - model_dir = os.path.dirname(model_file) - if not os.path.exists(model_dir): - os.makedirs(model_dir) - server_args.extend(['--model', model_file]) server_args.extend(['--hf-repo', args.hf_repo]) server_args.extend(['--hf-file', args.hf_file]) server_args.extend(['--n-gpu-layers', args.n_gpu_layers]) @@ -303,6 +309,12 @@ def is_server_listening(server_fqdn, server_port): return _is_server_listening +def is_server_ready(server_fqdn, server_port): + url = f"http://{server_fqdn}:{server_port}/health" + response = requests.get(url) + return response.status_code == 200 + + def escape_metric_name(metric_name): return re.sub('[^A-Z0-9]', '_', 
metric_name.upper()) diff --git a/examples/server/bench/script.js b/examples/server/bench/script.js index bdf4f5abc..2772bee5e 100644 --- a/examples/server/bench/script.js +++ b/examples/server/bench/script.js @@ -56,6 +56,7 @@ const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens') const llamacpp_tokens_second = new Trend('llamacpp_tokens_second') const llamacpp_prompt_processing_second = new Trend('llamacpp_prompt_processing_second') +const llamacpp_emit_first_token_second = new Trend('llamacpp_emit_first_token_second') const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter') const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter') @@ -89,6 +90,9 @@ export default function () { ], "model": model, "stream": true, + "stream_options": { + "include_usage": true, // False to be supported in llama.cpp server + }, "seed": 42, "max_tokens": max_tokens, "stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS @@ -105,12 +109,20 @@ export default function () { client.on('event', function (event) { if (promptEvalEndTime == null) { promptEvalEndTime = new Date() + llamacpp_emit_first_token_second.add((promptEvalEndTime - startTime) / 1.e3) + } + + if (event.data === '[DONE]' || event.data === '') { + return } let chunk = JSON.parse(event.data) - let choice = chunk.choices[0] - if (choice.finish_reason) { - finish_reason = choice.finish_reason + + if (chunk.choices && chunk.choices.length > 0) { + let choice = chunk.choices[0] + if (choice.finish_reason) { + finish_reason = choice.finish_reason + } } if (chunk.usage) { diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index 36f9c9fe9..3640a7a6c 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3558ddb7c..127323e77 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -67,6 +67,13 @@ enum server_task_type { SERVER_TASK_TYPE_SET_LORA, }; +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, +}; + // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 enum error_type { ERROR_TYPE_INVALID_REQUEST, @@ -91,6 +98,8 @@ struct slot_params { int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + std::vector lora; + std::vector antiprompt; std::vector response_fields; bool timings_per_token = false; @@ -101,11 +110,10 @@ struct slot_params { struct common_params_speculative speculative; // OAI-compat fields - bool verbose = false; - bool oaicompat = false; - bool oaicompat_chat = true; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; json to_json() const { std::vector samplers; @@ -114,6 +122,11 @@ struct slot_params { samplers.emplace_back(common_sampler_type_to_str(sampler)); } + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + return json { {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, @@ -154,6 +167,7 @@ 
struct slot_params { {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, }; } }; @@ -183,6 +197,9 @@ struct server_task { // used by SERVER_TASK_TYPE_METRICS bool metrics_reset_bucket = false; + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + server_task(server_task_type type) : type(type) {} static slot_params params_from_json_cmpl( @@ -245,6 +262,16 @@ struct server_task { params.speculative.n_min = std::max(params.speculative.n_min, 2); params.speculative.n_max = std::max(params.speculative.n_max, 0); + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + // TODO: add more sanity checks for the input parameters if (params.sampling.penalty_last_n < -1) { @@ -529,11 +556,10 @@ struct server_task_result_cmpl_final : server_task_result { slot_params generation_params; // OAI-compat fields - bool verbose = false; - bool oaicompat = false; - bool oaicompat_chat = true; // TODO: support oaicompat for non-chat - std::string oaicompat_model; - std::string oaicompat_cmpl_id; + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; virtual int get_index() override { return index; @@ -544,9 +570,16 @@ struct server_task_result_cmpl_final : server_task_result { } virtual json to_json() override { - return oaicompat - ? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat()) - : to_json_non_oaicompat(); + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } } json to_json_non_oaicompat() { @@ -574,6 +607,50 @@ struct server_task_result_cmpl_final : server_task_result { return response_fields.empty() ? res : json_get_nested_values(response_fields, res); } + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? 
"" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + json to_json_oaicompat_chat() { std::string finish_reason = "length"; if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { @@ -671,11 +748,10 @@ struct server_task_result_cmpl_partial : server_task_result { result_timings timings; // OAI-compat fields - bool verbose = false; - bool oaicompat = false; - bool oaicompat_chat = true; // TODO: support oaicompat for non-chat - std::string oaicompat_model; - std::string oaicompat_cmpl_id; + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; virtual int get_index() override { return index; @@ -686,7 +762,16 @@ struct server_task_result_cmpl_partial : server_task_result { } virtual json to_json() override { - return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat(); + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } } json to_json_non_oaicompat() { @@ -711,6 +796,41 @@ struct server_task_result_cmpl_partial : server_task_result { } json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { bool first = n_decoded == 0; std::time_t t = std::time(0); json choices; @@ -789,14 +909,16 @@ struct server_task_result_embd : server_task_result { int32_t n_tokens; // OAI-compat fields - bool oaicompat = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; virtual int get_index() override { return index; } virtual json to_json() override { - return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat(); + return oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? 
to_json_oaicompat() + : to_json_non_oaicompat(); } json to_json_non_oaicompat() { @@ -1009,6 +1131,8 @@ struct server_slot { common_speculative * spec = nullptr; + std::vector lora; + // the index relative to completion multi-task request size_t index = 0; @@ -1090,6 +1214,11 @@ struct server_slot { return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; } + bool can_batch_with(server_slot & other_slot) { + return is_non_causal() == other_slot.is_non_causal() + && are_lora_equal(lora, other_slot.lora); + } + bool has_budget(const common_params & global_params) { if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless @@ -1497,11 +1626,15 @@ struct server_response { struct server_context { common_params params_base; + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + llama_model * model = nullptr; llama_context * ctx = nullptr; - std::vector loras; llama_model * model_dft = nullptr; + llama_context_params cparams_dft; llama_batch batch = {}; @@ -1525,21 +1658,6 @@ struct server_context { float slot_prompt_similarity = 0.0f; ~server_context() { - if (ctx) { - llama_free(ctx); - ctx = nullptr; - } - - if (model) { - llama_free_model(model); - model = nullptr; - } - - if (model_dft) { - llama_free_model(model_dft); - model_dft = nullptr; - } - // Clear any sampling context for (server_slot & slot : slots) { common_sampler_free(slot.smpl); @@ -1562,11 +1680,10 @@ struct server_context { params_base = params; - common_init_result llama_init = common_init_from_params(params_base); + llama_init = common_init_from_params(params_base); - model = llama_init.model; - ctx = llama_init.context; - loras = llama_init.lora_adapters; + model = llama_init.model.get(); + ctx = llama_init.context.get(); if (model == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); @@ -1589,25 +1706,22 @@ struct server_context { params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_parallel = 1; - common_init_result llama_init_dft = common_init_from_params(params_dft); + llama_init_dft = common_init_from_params(params_dft); - model_dft = llama_init_dft.model; + model_dft = llama_init_dft.model.get(); if (model_dft == nullptr) { SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str()); return false; } - if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) { + if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str()); - llama_free (llama_init_dft.context); - llama_free_model(llama_init_dft.model); - return false; } - const int n_ctx_dft = llama_n_ctx(llama_init_dft.context); + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); cparams_dft = common_context_params_to_llama(params_dft); cparams_dft.n_batch = n_ctx_dft; @@ -1615,25 +1729,15 @@ struct server_context { // force F16 KV cache for the draft model for extra performance cparams_dft.type_k = GGML_TYPE_F16; cparams_dft.type_v = GGML_TYPE_F16; - - // the context is not needed - we will create one for each slot - llama_free(llama_init_dft.context); } return true; } - bool validate_model_chat_template() const { - std::vector model_template(2048, 0); // longest known template is about 1200 bytes - std::string template_key = 
"tokenizer.chat_template"; - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); - if (res >= 0) { - llama_chat_message chat[] = {{"user", "test"}}; - std::string tmpl = std::string(model_template.data(), model_template.size()); - int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0); - return chat_res > 0; - } - return false; + bool validate_builtin_chat_template() const { + llama_chat_message chat[] = {{"user", "test"}}; + int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); + return chat_res > 0; } void init() { @@ -1772,6 +1876,12 @@ struct server_context { slot.params = std::move(task.params); slot.prompt_tokens = std::move(task.prompt_tokens); + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = task.params.lora; + } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { @@ -2044,7 +2154,6 @@ struct server_context { res->verbose = slot.params.verbose; res->oaicompat = slot.params.oaicompat; - res->oaicompat_chat = slot.params.oaicompat_chat; res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; @@ -2085,7 +2194,6 @@ struct server_context { res->verbose = slot.params.verbose; res->stream = slot.params.stream; res->oaicompat = slot.params.oaicompat; - res->oaicompat_chat = slot.params.oaicompat_chat; res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; @@ -2465,7 +2573,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SET_LORA: { - common_lora_adapters_apply(ctx, loras); + params_base.lora_adapters = std::move(task.set_lora); auto res = std::make_unique(); res->id = task.id; queue_results.send(std::move(res)); @@ -2542,12 +2650,22 @@ struct server_context { // start populating the batch for this iteration common_batch_clear(batch); + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + // frist, add sampled tokens from any ongoing sequences for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { continue; } + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + slot.i_batch = batch.n_tokens; common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); @@ -2566,15 +2684,18 @@ struct server_context { int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); - // track if this is an embedding or non-embedding batch - // if we've added sampled tokens above, we are in non-embedding mode - // -1: none, 0: non-embedding, 1: embedding - // TODO: make enum - int32_t batch_type = batch.n_tokens > 0 ? 
0 : -1; - // next, batch any pending prompts without exceeding n_batch if (params_base.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { + // check if we can batch this slot with the previous one + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + } + // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { auto & prompt_tokens = slot.prompt_tokens; @@ -2735,14 +2856,6 @@ struct server_context { } } - // check that we are in the right batch_type, if not defer the slot - int slot_type = slot.is_non_causal(); - if (batch_type == -1) { - batch_type = slot_type; - } else if (batch_type != slot_type) { - continue; - } - // keep only the common part if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) @@ -2810,8 +2923,12 @@ struct server_context { SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - // make sure we're in the right embedding mode - llama_set_embeddings(ctx, batch_type == 1); + if (slot_batched) { + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, slot_batched->is_non_causal()); + // apply lora, only need to do it once per batch + common_lora_adapters_apply(ctx, slot_batched->lora); + } // process the created batch of tokens for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { @@ -3484,7 +3601,7 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", llama_get_chat_template(ctx_server.model) }, + { "chat_template", common_get_builtin_chat_template(ctx_server.model) }, { "build_info", build_info }, }; @@ -3506,12 +3623,11 @@ int main(int argc, char ** argv) { // handle completion-like requests (completion, chat, infill) // we can optionally provide a custom format for partial results and final results - const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok]( + const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( server_task_type type, json & data, httplib::Response & res, - bool oaicompat = false, - bool oaicompat_chat = false) { + oaicompat_type oaicompat) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); if (ctx_server.params_base.embedding) { @@ -3532,13 +3648,16 @@ int main(int argc, char ** argv) { task.index = i; task.prompt_tokens = std::move(tokenized_prompts[i]); - task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data); + task.params = server_task::params_from_json_cmpl( + ctx_server.model, + ctx_server.ctx, + ctx_server.params_base, + data); task.id_selected_slot = json_value(data, "id_slot", -1); // OAI-compat - task.params.oaicompat = oaicompat; - task.params.oaicompat_chat = oaicompat_chat; - task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl tasks.push_back(task); @@ -3589,7 +3708,7 @@ int main(int argc, char ** argv) { }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); }); - if (oaicompat) { + if (oaicompat != OAICOMPAT_TYPE_NONE) { static const std::string ev_done = "data: 
[DONE]\n\n"; sink.write(ev_done.data(), ev_done.size()); } @@ -3605,17 +3724,25 @@ int main(int argc, char ** argv) { } }; - const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { + const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); - return handle_completions_generic( + return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, res, - /* oaicompat */ false, - /* oaicompat_chat */ false); + OAICOMPAT_TYPE_NONE); }; - const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { + const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + json data = oaicompat_completion_params_parse(json::parse(req.body)); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + res, + OAICOMPAT_TYPE_COMPLETION); + }; + + const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { // check model compatibility std::string err; if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) { @@ -3670,7 +3797,7 @@ int main(int argc, char ** argv) { data["input_extra"] = input_extra; // default to empty array if it's not exist std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, false, true); SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); data["prompt"] = format_infill( ctx_server.ctx, @@ -3684,22 +3811,25 @@ int main(int argc, char ** argv) { tokenized_prompts[0] ); - return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res); + return handle_completions_impl( + SERVER_TASK_TYPE_INFILL, + data, + res, + OAICOMPAT_TYPE_NONE); // infill is not OAI compatible }; - const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { + const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. 
Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - return handle_completions_generic( + json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); + return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, res, - /* oaicompat */ true, - /* oaicompat_chat */ true); + OAICOMPAT_TYPE_CHAT); }; const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { @@ -3772,10 +3902,10 @@ int main(int argc, char ** argv) { res_ok(res, data); }; - const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) { + const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) { const json body = json::parse(req.body); - if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -3785,7 +3915,7 @@ int main(int argc, char ** argv) { if (body.count("input") != 0) { prompt = body.at("input"); } else if (body.contains("content")) { - oaicompat = false; + oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible prompt = body.at("content"); } else { res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); @@ -3854,16 +3984,18 @@ int main(int argc, char ** argv) { } // write JSON response - json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses); + json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? 
format_embeddings_response_oaicompat(body, responses, use_base64) + : json(responses); res_ok(res, root); }; const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { - handle_embeddings_impl(req, res, false); + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); }; const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { - handle_embeddings_impl(req, res, true); + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); }; const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { @@ -3946,8 +4078,9 @@ int main(int argc, char ** argv) { const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); - for (size_t i = 0; i < ctx_server.loras.size(); ++i) { - auto & lora = ctx_server.loras[i]; + const auto & loras = ctx_server.params_base.lora_adapters; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; result.push_back({ {"id", i}, {"path", lora.path}, @@ -3959,27 +4092,14 @@ int main(int argc, char ** argv) { }; const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - const std::vector body = json::parse(req.body); - int max_idx = ctx_server.loras.size(); - - // clear existing value - for (auto & lora : ctx_server.loras) { - lora.scale = 0.0f; + const json body = json::parse(req.body); + if (!body.is_array()) { + res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return; } - - // set value - for (auto entry : body) { - int id = entry.at("id"); - float scale = entry.at("scale"); - if (0 <= id && id < max_idx) { - ctx_server.loras[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } - } - server_task task(SERVER_TASK_TYPE_SET_LORA); task.id = ctx_server.queue_tasks.get_new_id(); + task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); ctx_server.queue_results.add_waiting_task_id(task.id); ctx_server.queue_tasks.post(task); @@ -4033,7 +4153,7 @@ int main(int argc, char ** argv) { svr->Get ("/v1/models", handle_models); // public endpoint (no API key check) svr->Post("/completion", handle_completions); // legacy svr->Post("/completions", handle_completions); - svr->Post("/v1/completions", handle_completions); + svr->Post("/v1/completions", handle_completions_oai); svr->Post("/chat/completions", handle_chat_completions); svr->Post("/v1/chat/completions", handle_chat_completions); svr->Post("/infill", handle_infill); @@ -4113,14 +4233,16 @@ int main(int argc, char ** argv) { // if a custom chat template is not supplied, we will use the one that comes with the model (if any) if (params.chat_template.empty()) { - if (!ctx_server.validate_model_chat_template()) { + if (!ctx_server.validate_builtin_chat_template()) { LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); params.chat_template = "chatml"; } } // print sample chat example to make it clear which template is used - LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str()); + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + params.chat_template.empty() ? 
"(built-in)" : params.chat_template.c_str(), + common_chat_format_example(ctx_server.model, params.chat_template).c_str()); ctx_server.queue_tasks.on_new_task(std::bind( &server_context::process_single_task, &ctx_server, std::placeholders::_1)); diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md index fa3d0a2f5..5787276ab 100644 --- a/examples/server/tests/README.md +++ b/examples/server/tests/README.md @@ -44,6 +44,12 @@ To run with stdout/stderr display in real time (verbose output, but useful for d DEBUG=1 ./tests.sh -s -v -x ``` +To run single test unit: + +```shell +./tests.sh unit/test_{name of test case here}.py -v -x +``` + Hint: You can compile and run test in single command, useful for local developement: ```shell diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt index 074b9d47b..15d024914 100644 --- a/examples/server/tests/requirements.txt +++ b/examples/server/tests/requirements.txt @@ -5,3 +5,4 @@ numpy~=1.26.4 openai~=1.55.3 prometheus-client~=0.20.0 requests~=2.32.3 +wget~=3.2 diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 885497081..b15dba6eb 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -83,7 +83,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte def test_chat_completion_with_openai_library(): global server server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") res = client.chat.completions.create( model="gpt-3.5-turbo-instruct", messages=[ @@ -100,6 +100,23 @@ def test_chat_completion_with_openai_library(): assert match_regex("(Suddenly)+", res.choices[0].message.content) +def test_chat_template(): + global server + server.chat_template = "llama3" + server.debug = True # to get the "__verbose" object in the response + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 8, + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ] + }) + assert res.status_code == 200 + assert "__verbose" in res.body + assert res.body["__verbose"]["prompt"] == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + + @pytest.mark.parametrize("response_format,n_predicted,re_content", [ ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""), ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"), @@ -170,7 +187,7 @@ def test_chat_completion_with_timings_per_token(): def test_logprobs(): global server server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") res = client.chat.completions.create( model="gpt-3.5-turbo-instruct", temperature=0.0, @@ -197,7 +214,7 @@ def test_logprobs(): def test_logprobs_stream(): global server server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") res = 
client.chat.completions.create( model="gpt-3.5-turbo-instruct", temperature=0.0, diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index a6b215944..e5e3b6077 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -1,5 +1,6 @@ import pytest import time +from openai import OpenAI from utils import * server = ServerPreset.tinyllama2() @@ -85,6 +86,40 @@ def test_completion_stream_vs_non_stream(): assert content_stream == res_non_stream.body["content"] +def test_completion_stream_with_openai_library(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.completions.create( + model="davinci-002", + prompt="I believe the meaning of life is", + max_tokens=8, + ) + assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b") + assert res.choices[0].finish_reason == "length" + assert res.choices[0].text is not None + assert match_regex("(going|bed)+", res.choices[0].text) + + +def test_completion_with_openai_library(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.completions.create( + model="davinci-002", + prompt="I believe the meaning of life is", + max_tokens=8, + stream=True, + ) + output_text = '' + for data in res: + choice = data.choices[0] + if choice.finish_reason is None: + assert choice.text is not None + output_text += choice.text + assert match_regex("(going|bed)+", output_text) + + @pytest.mark.parametrize("n_slots", [1, 2]) def test_consistent_result_same_seed(n_slots: int): global server diff --git a/examples/server/tests/unit/test_infill.py b/examples/server/tests/unit/test_infill.py index ad4b8192a..10554db0f 100644 --- a/examples/server/tests/unit/test_infill.py +++ b/examples/server/tests/unit/test_infill.py @@ -18,7 +18,7 @@ def test_infill_without_input_extra(): "input_suffix": "}\n", }) assert res.status_code == 200 - assert match_regex("(Ann|small|shiny)+", res.body["content"]) + assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"]) def test_infill_with_input_extra(): diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py index 749615449..c1aa8be70 100644 --- a/examples/server/tests/unit/test_lora.py +++ b/examples/server/tests/unit/test_lora.py @@ -1,5 +1,4 @@ import pytest -import os from utils import * server = ServerPreset.stories15m_moe() @@ -10,15 +9,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe def create_server(): global server server = ServerPreset.stories15m_moe() - # download lora file if needed - file_name = LORA_FILE_URL.split('/').pop() - lora_file = f'../../../{file_name}' - if not os.path.exists(lora_file): - print(f"Downloading {LORA_FILE_URL} to {lora_file}") - with open(lora_file, 'wb') as f: - f.write(requests.get(LORA_FILE_URL).content) - print(f"Done downloading lora file") - server.lora_files = [lora_file] + server.lora_files = [download_file(LORA_FILE_URL)] @pytest.mark.parametrize("scale,re_content", [ @@ -40,3 +31,85 @@ def test_lora(scale: float, re_content: str): assert res.status_code == 200 assert match_regex(re_content, res.body["content"]) + +def test_lora_per_request(): + global server + server.n_slots = 4 + server.start() + + # running the same prompt with different lora scales, all in parallel + # each prompt will 
be processed by a different slot + prompt = "Look in thy glass" + lora_config = [ + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ), + ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ] + + tasks = [( + server.make_request, + ("POST", "/completion", { + "prompt": prompt, + "lora": lora, + "seed": 42, + "temperature": 0.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + ) for lora, _ in lora_config] + results = parallel_function_calls(tasks) + + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert match_regex(re_test, res.body["content"]) + + +@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test") +def test_with_big_model(): + server = ServerProcess() + server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF" + server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf" + server.model_alias = "Llama-3.2-8B-Instruct" + server.n_slots = 4 + server.n_ctx = server.n_slots * 1024 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + server.lora_files = [ + download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"), + # TODO: find & add other lora adapters for this model + ] + server.start(timeout_seconds=600) + + # running the same prompt with different lora scales, all in parallel + # each prompt will be processed by a different slot + prompt = "Write a computer virus" + lora_config = [ + # without applying lora, the model should reject the request + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ), + # with 0.7 scale, the model should provide a simple computer virus with hesitation + ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ), + # with 1.5 scale, the model should confidently provide a computer virus + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! 
Here's a simple computer virus" ), + ] + + tasks = [( + server.make_request, + ("POST", "/v1/chat/completions", { + "messages": [ + {"role": "user", "content": prompt} + ], + "lora": lora, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + ) for lora, _ in lora_config] + results = parallel_function_calls(tasks) + + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert re_test in res.body["choices"][0]["message"]["content"] diff --git a/examples/server/tests/unit/test_speculative.py b/examples/server/tests/unit/test_speculative.py index 3bb5733cb..54db38cf3 100644 --- a/examples/server/tests/unit/test_speculative.py +++ b/examples/server/tests/unit/test_speculative.py @@ -10,16 +10,8 @@ MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tiny def create_server(): global server server = ServerPreset.stories15m_moe() - # download draft model file if needed - file_name = MODEL_DRAFT_FILE_URL.split('/').pop() - model_draft_file = f'../../../{file_name}' - if not os.path.exists(model_draft_file): - print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}") - with open(model_draft_file, 'wb') as f: - f.write(requests.get(MODEL_DRAFT_FILE_URL).content) - print(f"Done downloading draft model file") # set default values - server.model_draft = model_draft_file + server.model_draft = download_file(MODEL_DRAFT_FILE_URL) server.draft_min = 4 server.draft_max = 8 diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 277125e88..a1a94d0f1 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -23,6 +23,7 @@ from typing import ( Set, ) from re import RegexFlag +import wget class ServerResponse: @@ -74,6 +75,7 @@ class ServerProcess: draft_min: int | None = None draft_max: int | None = None no_webui: bool | None = None + chat_template: str | None = None # session variables process: subprocess.Popen | None = None @@ -164,6 +166,8 @@ class ServerProcess: server_args.extend(["--draft-min", self.draft_min]) if self.no_webui: server_args.append("--no-webui") + if self.chat_template: + server_args.extend(["--chat-template", self.chat_template]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") @@ -378,5 +382,25 @@ def match_regex(regex: str, text: str) -> bool: is not None ) + +def download_file(url: str, output_file_path: str | None = None) -> str: + """ + Download a file from a URL to a local path. If the file already exists, it will not be downloaded again. + + output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory. + + Returns the local path of the downloaded file. 
+ """ + file_name = url.split('/').pop() + output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path + if not os.path.exists(output_file): + print(f"Downloading {url} to {output_file}") + wget.download(url, out=output_file) + print(f"Done downloading to {output_file}") + else: + print(f"File already exists at {output_file}") + return output_file + + def is_slow_test_allowed(): return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON" diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 334f2f192..ad130d490 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -382,19 +382,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri return formatted_chat; } -static std::string llama_get_chat_template(const struct llama_model * model) { - std::string template_key = "tokenizer.chat_template"; - // call with NULL buffer to get the total size of the string - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0); - if (res < 2) { - return ""; - } else { - std::vector model_template(res + 1, 0); - llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); - return std::string(model_template.data(), model_template.size() - 1); - } -} - // // base64 utils (TODO: move to common in the future) // @@ -520,7 +507,7 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { // format incomplete utf-8 multibyte character for output static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { - std::string out = token == -1 ? "" : common_token_to_piece(ctx, token); + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); // if the size is 1 and first bit is 1, meaning it's a partial character // (size > 1 meaning it's already a known token) @@ -549,10 +536,49 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons // OAI utils // -static json oaicompat_completion_params_parse( - const struct llama_model * model, - const json & body, /* openai api json semantics */ - const std::string & chat_template) { +static json oaicompat_completion_params_parse(const json & body) { + json llama_params; + + if (!body.contains("prompt")) { + throw std::runtime_error("\"prompt\" is required"); + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Params supported by OAI but unsupported by llama.cpp + static const std::vector unsupported_params { "best_of", "echo", "suffix" }; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); + } + } + + // Copy remaining properties to llama_params + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +static json oaicompat_chat_completion_params_parse( + const struct llama_model * model, + const json & body, /* 
openai api json semantics */ + const std::string & chat_template) { json llama_params; // Apply chat template to the list of messages @@ -771,3 +797,44 @@ static std::vector get_token_probabilities(llama_context * ctx return cur; } + +static bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; +} + +// parse lora config from JSON request, returned a copy of lora_base with updated scale +static std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; +} diff --git a/examples/server/webui/index.html b/examples/server/webui/index.html index dcdd41079..86a79b77f 100644 --- a/examples/server/webui/index.html +++ b/examples/server/webui/index.html @@ -62,53 +62,57 @@
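Note on the per-request `lora` field used by the server changes above: a request may carry an array of `{"id", "scale"}` objects, and `parse_lora_request` (examples/server/utils.hpp) copies the server's base adapter list, zeroes every scale, then applies the scales for the listed ids, rejecting out-of-range ids. The standalone sketch below mirrors that behaviour outside the server; `lora_info` and `apply_lora_request` are hypothetical stand-ins for the server's internal types and are not part of the patch — only the id/scale logic is modelled, using nlohmann::json as the server itself does.

```cpp
// Sketch only: models the per-request lora semantics described in the patch.
#include <nlohmann/json.hpp>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

struct lora_info {          // hypothetical stand-in for the server's adapter entry
    std::string path;
    float scale = 0.0f;
};

// copy the base list, reset all scales to 0, then apply the scales from the request
static std::vector<lora_info> apply_lora_request(const std::vector<lora_info> & base, const json & data) {
    std::vector<lora_info> lora(base);
    for (auto & entry : lora) {
        entry.scale = 0.0f;
    }
    for (const auto & entry : data) {
        const int   id    = entry.value("id", -1);
        const float scale = entry.value("scale", 0.0f);
        if (id < 0 || id >= (int) lora.size()) {
            throw std::runtime_error("invalid adapter id");
        }
        lora[id].scale = scale;
    }
    return lora;
}

int main() {
    // two adapters loaded at startup via --lora; the request enables only the first
    std::vector<lora_info> base = { {"adapter_a.gguf"}, {"adapter_b.gguf"} };
    const json request = json::parse(R"([ {"id": 0, "scale": 0.5} ])");

    for (const auto & l : apply_lora_request(base, request)) {
        std::printf("%s -> scale %.2f\n", l.path.c_str(), l.scale);
    }
    return 0;
}
```

With a request body like `[{"id": 0, "scale": 0.5}]`, a slot runs with only adapter 0 active while the server-wide defaults in `params_base.lora_adapters` stay untouched; this is also why `can_batch_with` in the patch compares the per-slot lora vectors (via `are_lora_equal`) before mixing slots into one decode batch.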
- diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 7f4da666b..d72f5bcdd 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -69,7 +69,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = llama_model_default_params(); model_params.n_gpu_layers = ngl; - llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params); + llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); if (!model) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); return 1; @@ -194,7 +194,7 @@ int main(int argc, char ** argv) { } llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 3288c0250..f69117890 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -83,7 +83,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = llama_model_default_params(); model_params.n_gpu_layers = ngl; - llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params); + llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); @@ -199,7 +199,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 8ca84f7af..9070c3512 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -34,7 +34,7 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); llama_model * model_tgt = NULL; - llama_model * model_dft = NULL; + //llama_model * model_dft = NULL; llama_context * ctx_tgt = NULL; llama_context * ctx_dft = NULL; @@ -42,8 +42,8 @@ int main(int argc, char ** argv) { // load the target model common_init_result llama_init_tgt = common_init_from_params(params); - model_tgt = llama_init_tgt.model; - ctx_tgt = llama_init_tgt.context; + model_tgt = llama_init_tgt.model.get(); + ctx_tgt = llama_init_tgt.context.get(); // load the draft model params.devices = params.speculative.devices; @@ -59,8 +59,8 @@ int main(int argc, char ** argv) { params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; common_init_result llama_init_dft = common_init_from_params(params); - model_dft = llama_init_dft.model; - ctx_dft = llama_init_dft.context; + //model_dft = llama_init_dft.model.get(); + ctx_dft = llama_init_dft.context.get(); if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { return 1; @@ -251,12 +251,6 @@ int main(int argc, char ** argv) { common_sampler_free(smpl); common_speculative_free(spec); - llama_free(ctx_tgt); - llama_free_model(model_tgt); - - llama_free(ctx_dft); - llama_free_model(model_dft); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index d4ad9751e..bc0b6813b 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -72,8 +72,9 @@ int main(int argc, char ** argv) { // load the target model common_init_result llama_init_tgt = common_init_from_params(params); - model_tgt = llama_init_tgt.model; - ctx_tgt = llama_init_tgt.context; + + model_tgt = 
llama_init_tgt.model.get(); + ctx_tgt = llama_init_tgt.context.get(); // load the draft model params.devices = params.speculative.devices; @@ -85,8 +86,9 @@ int main(int argc, char ** argv) { params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; common_init_result llama_init_dft = common_init_from_params(params); - model_dft = llama_init_dft.model; - ctx_dft = llama_init_dft.context; + + model_dft = llama_init_dft.model.get(); + ctx_dft = llama_init_dft.context.get(); const bool vocab_type_tgt = llama_vocab_type(model_tgt); LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt); @@ -631,12 +633,6 @@ int main(int argc, char ** argv) { llama_batch_free(batch_dft); - llama_free(ctx_tgt); - llama_free_model(model_tgt); - - llama_free(ctx_dft); - llama_free_model(model_dft); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index c97e22724..684ca054a 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -31,6 +31,7 @@ static void print_usage_information(const char * argv0) { printf(" -p PROMPT, --prompt PROMPT read prompt from the argument.\n"); printf(" --stdin read prompt from standard input.\n"); printf(" --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n"); + printf(" --no-escape do not escape input (such as \\n, \\t, etc.).\n"); printf(" --no-parse-special do not parse control tokens.\n"); printf(" --log-disable disable logs. Makes stderr quiet when loading the model.\n"); printf(" --show-count print the total number of tokens.\n"); @@ -198,6 +199,7 @@ int main(int raw_argc, char ** raw_argv) { // variables where to put any arguments we see. bool printing_ids = false; bool no_bos = false; + bool no_escape = false; bool no_parse_special = false; bool disable_logging = false; bool show_token_count = false; @@ -233,6 +235,9 @@ int main(int raw_argc, char ** raw_argv) { else if (arg == "--no-bos") { no_bos = true; } + else if (arg == "--no-escape") { + no_escape = true; + } else if (arg == "--no-parse-special") { no_parse_special = true; } @@ -333,7 +338,7 @@ int main(int raw_argc, char ** raw_argv) { llama_model_params model_params = llama_model_default_params(); model_params.vocab_only = true; - llama_model * model = llama_load_model_from_file(model_path, model_params); + llama_model * model = llama_model_load_from_file(model_path, model_params); if (!model) { fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path); return 1; @@ -363,6 +368,11 @@ int main(int raw_argc, char ** raw_argv) { const bool model_wants_add_bos = llama_add_bos_token(model); const bool add_bos = model_wants_add_bos && !no_bos; const bool parse_special = !no_parse_special; + const bool escape = !no_escape; + + if (escape) { + string_process_escapes(prompt); + } std::vector tokens; tokens = common_tokenize(model, prompt, add_bos, parse_special); @@ -398,7 +408,7 @@ int main(int raw_argc, char ** raw_argv) { } // silence valgrind llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 7f36b80f0..522f5e881 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -458,8 +458,9 @@ int main(int argc, char ** argv) { llama_context * ctx_cts = NULL; common_init_result llama_init_ttc = common_init_from_params(params); - model_ttc = llama_init_ttc.model; - ctx_ttc = llama_init_ttc.context; + + model_ttc = llama_init_ttc.model.get(); + ctx_ttc = 
llama_init_ttc.context.get(); // TODO: refactor in a common struct params.model = params.vocoder.model; @@ -470,8 +471,9 @@ int main(int argc, char ** argv) { params.embedding = true; common_init_result llama_init_cts = common_init_from_params(params); - model_cts = llama_init_cts.model; - ctx_cts = llama_init_cts.context; + + model_cts = llama_init_cts.model.get(); + ctx_cts = llama_init_cts.context.get(); std::vector smpl(n_parallel); for (int i = 0; i < n_parallel; ++i) { @@ -920,12 +922,6 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str()); - llama_free(ctx_ttc); - llama_free_model(model_ttc); - - llama_free(ctx_cts); - llama_free_model(model_cts); - llama_backend_free(); return 0; diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index e33d97482..fe8acc803 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -243,7 +243,8 @@ set(GGML_PUBLIC_HEADERS include/ggml-metal.h include/ggml-rpc.h include/ggml-sycl.h - include/ggml-vulkan.h) + include/ggml-vulkan.h + include/gguf.h) set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") #if (GGML_METAL) @@ -252,26 +253,6 @@ set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") install(TARGETS ggml LIBRARY PUBLIC_HEADER) install(TARGETS ggml-base LIBRARY) -# FIXME: this should be done in the backend cmake files -if (GGML_METAL) - # FIXME: does this need to be installed with GGML_METAL_EMBED_LIBRARY? - install( - FILES src/ggml-metal/ggml-metal.metal - PERMISSIONS - OWNER_READ - OWNER_WRITE - GROUP_READ - WORLD_READ - DESTINATION ${CMAKE_INSTALL_BINDIR}) - - if (NOT GGML_METAL_EMBED_LIBRARY) - install( - FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib - DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - endif() -endif() - if (GGML_STANDALONE) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc diff --git a/ggml/include/ggml-cpp.h b/ggml/include/ggml-cpp.h index 219361af4..a12342c25 100644 --- a/ggml/include/ggml-cpp.h +++ b/ggml/include/ggml-cpp.h @@ -7,6 +7,7 @@ #include "ggml.h" #include "ggml-alloc.h" #include "ggml-backend.h" +#include "gguf.h" #include // Smart pointers for ggml types diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c714fc8c8..8630d92c5 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -241,12 +241,6 @@ #define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_VISION 24 -#define GGUF_MAGIC "GGUF" - -#define GGUF_VERSION 3 - -#define GGUF_DEFAULT_ALIGNMENT 32 - #define GGML_UNUSED(x) (void)(x) #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) @@ -403,12 +397,6 @@ extern "C" { GGML_PREC_F32, }; - enum ggml_backend_type { - GGML_BACKEND_TYPE_CPU = 0, - GGML_BACKEND_TYPE_GPU = 10, - GGML_BACKEND_TYPE_GPU_SPLIT = 20, - }; - // model file types enum ggml_ftype { GGML_FTYPE_UNKNOWN = -1, @@ -587,8 +575,6 @@ extern "C" { struct ggml_tensor { enum ggml_type type; - GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor"); - struct ggml_backend_buffer * buffer; int64_t ne[GGML_MAX_DIMS]; // number of elements @@ -2111,132 +2097,6 @@ extern "C" { int64_t n_per_row, const float * imatrix); - // - // gguf - // - - enum gguf_type { - GGUF_TYPE_UINT8 = 0, - GGUF_TYPE_INT8 = 1, - GGUF_TYPE_UINT16 = 2, - GGUF_TYPE_INT16 = 3, - GGUF_TYPE_UINT32 = 4, - GGUF_TYPE_INT32 = 5, - GGUF_TYPE_FLOAT32 = 6, - GGUF_TYPE_BOOL = 7, - GGUF_TYPE_STRING = 8, - GGUF_TYPE_ARRAY = 9, - 
GGUF_TYPE_UINT64 = 10, - GGUF_TYPE_INT64 = 11, - GGUF_TYPE_FLOAT64 = 12, - GGUF_TYPE_COUNT, // marks the end of the enum - }; - - struct gguf_context; - - struct gguf_init_params { - bool no_alloc; - - // if not NULL, create a ggml_context and allocate the tensor data in it - struct ggml_context ** ctx; - }; - - GGML_API struct gguf_context * gguf_init_empty(void); - GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); - //GGML_API struct gguf_context * gguf_init_from_buffer(..); - - GGML_API void gguf_free(struct gguf_context * ctx); - - GGML_API const char * gguf_type_name(enum gguf_type type); - - GGML_API int gguf_get_version (const struct gguf_context * ctx); - GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); - GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); - GGML_API void * gguf_get_data (const struct gguf_context * ctx); - - GGML_API int gguf_get_n_kv(const struct gguf_context * ctx); - GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key); - GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id); - - GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id); - GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id); - - // will abort if the wrong type is used for the key - GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int key_id); - GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int key_id); - GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int key_id); - GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int key_id); - GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int key_id); - GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int key_id); - GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int key_id); - GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int key_id); - GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key_id); - GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); - GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); - GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); - GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id); - GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); - GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); - GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); - - GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); - GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name); - GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i); - GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i); - GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int i); - - // removes key if it exists - GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key); - - // overrides existing values or adds a new one - GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); - GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); - GGML_API void gguf_set_val_u16 (struct gguf_context 
* ctx, const char * key, uint16_t val); - GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); - GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); - GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); - GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); - GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); - GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); - GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val); - GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); - GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); - GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n); - GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n); - - // set or add KV pairs from another context - GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src); - - // manage tensor info - GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); - GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); - GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size); - - // writing gguf files can be done in 2 ways: - // - // - write the entire gguf_context to a binary file in a single pass: - // - // gguf_write_to_file(ctx, fname); - // - // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: - // - // FILE * f = fopen(fname, "wb"); - // fseek(f, gguf_get_meta_size(ctx), SEEK_SET); - // fwrite(f, ...); - // void * data = gguf_meta_get_meta_data(ctx); - // fseek(f, 0, SEEK_SET); - // fwrite(f, data, gguf_get_meta_size(ctx)); - // free(data); - // fclose(f); - // - - // write the entire context to a binary file - GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); - - // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding - GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); - GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); - #ifdef __cplusplus // restrict not standard in C++ # if defined(__GNUC__) diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h new file mode 100644 index 000000000..79ee20206 --- /dev/null +++ b/ggml/include/gguf.h @@ -0,0 +1,202 @@ +// This file contains functionality related to "GGUF" files, the binary file format used by ggml. +// GGUF files have the following structure: +// +// 1. File magic "GGUF" (4 bytes). +// 2. File version (uint32_t). +// 3. Number of ggml tensors in file (int64_t). +// 4. Number of key-value-pairs in file (int64_t). +// 5. For each KV pair: +// 1. The key (string). +// 2. The value type (gguf_type). +// 3a. If the value type is GGUF_TYPE_ARRAY: +// 1. The type of the array (gguf_type). +// 2. The number of elements in the array (uint64_t). +// 3. The binary representation of each element in the array. +// 3b. Otherwise: +// 1. The binary representation of the value. +// 6. For each ggml tensor: +// 1. The tensor name (string). +// 2. 
The number of dimensions of the tensor (uint32_t). +// 3. For each dimension: +// 1. The size of the tensor in the dimension (int64_t). +// 4. The tensor data type (ggml_type). +// 5. The tensor data offset in the tensor data binary blob (uint64_t). +// 7. The tensor data binary blob (optional, aligned). +// +// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator. +// All enums are stored as int32_t. +// All bool values are stored as int8_t. +// If the special key "general.alignment" (uint32_t) is defined it is used for alignment, +// otherwise GGUF_DEFAULT_ALIGNMENT is used. +// +// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) + +#pragma once + +#include "ggml.h" + +#include +#include + +#define GGUF_MAGIC "GGUF" +#define GGUF_VERSION 3 + +#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment" + +#define GGUF_DEFAULT_ALIGNMENT 32 + +#ifdef __cplusplus +extern "C" { +#endif + + // types that can be stored as GGUF KV data + enum gguf_type { + GGUF_TYPE_UINT8 = 0, + GGUF_TYPE_INT8 = 1, + GGUF_TYPE_UINT16 = 2, + GGUF_TYPE_INT16 = 3, + GGUF_TYPE_UINT32 = 4, + GGUF_TYPE_INT32 = 5, + GGUF_TYPE_FLOAT32 = 6, + GGUF_TYPE_BOOL = 7, + GGUF_TYPE_STRING = 8, + GGUF_TYPE_ARRAY = 9, + GGUF_TYPE_UINT64 = 10, + GGUF_TYPE_INT64 = 11, + GGUF_TYPE_FLOAT64 = 12, + GGUF_TYPE_COUNT, // marks the end of the enum + }; + + struct gguf_context; + + struct gguf_init_params { + bool no_alloc; + + // if not NULL, create a ggml_context and allocate the tensor data in it + struct ggml_context ** ctx; + }; + + GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); + //GGML_API struct gguf_context * gguf_init_from_buffer(..); + + GGML_API void gguf_free(struct gguf_context * ctx); + + GGML_API const char * gguf_type_name(enum gguf_type type); + + GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx); + GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + + GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx); + GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found + GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id); + + GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id); + GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id); + + // will abort if the wrong type is used for the key + GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int64_t key_id); + GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id); + GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id); + GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id); + GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id); + GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id); + GGML_API bool 
gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id); + GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id); + GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id); + GGML_API size_t gguf_get_arr_n (const struct gguf_context * ctx, int64_t key_id); + + // get raw pointer to the first element of the array with the given key_id + // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference) + GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id); + + // get ith C string from array with given key_id + GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); + + GGML_API int64_t gguf_get_n_tensors (const struct gguf_context * ctx); + GGML_API int64_t gguf_find_tensor (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found + GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id); + GGML_API const char * gguf_get_tensor_name (const struct gguf_context * ctx, int64_t tensor_id); + GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int64_t tensor_id); + GGML_API size_t gguf_get_tensor_size (const struct gguf_context * ctx, int64_t tensor_id); + + // removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist) + GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key); + + // overrides an existing KV pair or adds a new one, the new KV pair is always at the back + GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); + GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); + GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); + GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); + GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); + GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); + GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); + GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); + GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); + GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val); + GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); + GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); + + // creates a new array with n elements of the given type and copies the corresponding number of bytes from data + GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n); + + // creates a new array with n strings and copies the corresponding strings from data + GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n); + + // set or add KV pairs from another context + GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src); + + // add tensor to GGUF context, tensor name must be unique + GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); + + // after changing a tensor's type, the offsets of 
all tensors with higher indices are immediately recalculated + // in such a way that the tensor data remains as one contiguous block (except for padding) + GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); + + // assumes that at least gguf_get_tensor_size bytes can be read from data + GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data); + + // writing gguf files can be done in 3 ways: + // + // - write the entire gguf_context to a binary file in a single pass: + // + // gguf_write_to_file(ctx, fname, /*only_meta =*/ false); + // + // - write only the meta data to a file, then re-open the file and append the tensor data: + // + // gguf_write_to_file(ctx, fname, /*only_meta =*/ true); + // FILE * f = fopen(fname, "ab"); + // fwrite(f, ...); // write tensor data + // fclose(f); + // + // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: + // + // FILE * f = fopen(fname, "wb"); + // const size_t size_meta = gguf_get_meta_size(ctx); + // fseek(f, size_meta, SEEK_SET); + // fwrite(f, ...); // write tensor data + // void * data = malloc(size_meta); + // gguf_get_meta_data(ctx, data); + // rewind(f); + // fwrite(data, 1, data, f); + // free(data); + // fclose(f); + // + + // write the entire context to a binary file + GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); + + // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding + GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); + + // writes the meta data to pointer "data" + GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index a5f7f7b5b..ae1cd2337 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -208,6 +208,7 @@ add_library(ggml-base ../include/ggml-backend.h ../include/ggml-cpp.h ../include/ggml-opt.h + ../include/gguf.h ggml.c ggml-alloc.c ggml-backend.cpp @@ -215,7 +216,8 @@ add_library(ggml-base ggml-threading.cpp ggml-threading.h ggml-quants.c - ggml-quants.h) + ggml-quants.h + gguf.cpp) target_include_directories(ggml-base PRIVATE .) 
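For concreteness, here is a minimal C++ sketch of the API declared in the new gguf.h above. It is not part of the diff: the key "example.version", the tensor name "example.weight", and the file name are hypothetical, and error handling is omitted. It writes a file in a single pass and then reads the KV pair and tensor info back.

```cpp
// Sketch only: assumes the program is linked against ggml/gguf as declared above.
#include "ggml.h"
#include "gguf.h"

#include <cstdio>

int main() {
    // create a tiny F32 tensor whose data will be embedded in the file
    struct ggml_init_params ip = { /*mem_size*/ 1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    struct ggml_context * gctx = ggml_init(ip);
    struct ggml_tensor * t = ggml_new_tensor_1d(gctx, GGML_TYPE_F32, 8);
    ggml_set_name(t, "example.weight"); // hypothetical tensor name
    for (int i = 0; i < 8; ++i) {
        ((float *) t->data)[i] = (float) i;
    }

    // build the GGUF context: one KV pair + one tensor, then write it in a single pass
    struct gguf_context * wctx = gguf_init_empty();
    gguf_set_val_u32(wctx, "example.version", 1); // hypothetical key
    gguf_add_tensor(wctx, t);
    gguf_write_to_file(wctx, "example.gguf", /*only_meta =*/ false);
    gguf_free(wctx);
    ggml_free(gctx);

    // read the file back and look up the KV pair and the tensor
    struct gguf_init_params gp = { /*no_alloc*/ true, /*ctx*/ nullptr };
    struct gguf_context * rctx = gguf_init_from_file("example.gguf", gp);
    const int64_t kid = gguf_find_key(rctx, "example.version");
    const int64_t tid = gguf_find_tensor(rctx, "example.weight");
    if (kid >= 0 && tid >= 0) {
        printf("version=%u, tensor bytes=%zu\n",
               gguf_get_val_u32(rctx, kid), gguf_get_tensor_size(rctx, tid));
    }
    gguf_free(rctx);
    return 0;
}
```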
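The third write strategy described in the header comment (reserve space for the meta data, write the tensor data, then go back and write the meta data) is only pseudo-code there; the sketch below spells out the fwrite() arguments. It assumes the caller supplies a tensor data blob laid out with the same offsets and padding that the meta data records, which is the caller's responsibility under this strategy.

```cpp
// Sketch of the "placeholder first" write path, using only the functions declared above.
#include "gguf.h"

#include <cstdio>
#include <cstdlib>

static bool write_gguf_two_pass(const gguf_context * ctx, const char * fname,
                                const void * tensor_blob, size_t tensor_blob_size) {
    FILE * f = fopen(fname, "wb");
    if (!f) {
        return false;
    }

    const size_t size_meta = gguf_get_meta_size(ctx);

    // 1. skip over the space reserved for the meta data
    fseek(f, (long) size_meta, SEEK_SET);

    // 2. write the tensor data blob
    fwrite(tensor_blob, 1, tensor_blob_size, f);

    // 3. go back and fill in the meta data
    void * meta = malloc(size_meta);
    gguf_get_meta_data(ctx, meta);
    rewind(f);
    fwrite(meta, 1, size_meta, f);

    free(meta);
    fclose(f);
    return true;
}
```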
@@ -290,9 +292,9 @@ if (GGML_CPU_ALL_VARIANTS) ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 FMA) ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512) ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI) + ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI) if (NOT MSVC) - # MSVC doesn't support AVX-VNNI or AMX - ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI) + # MSVC doesn't support AMX ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) endif() else () diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 7ddd178b5..955ed505f 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -574,4 +574,9 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("opencl", silent, dir_path); ggml_backend_load_best("musa", silent, dir_path); ggml_backend_load_best("cpu", silent, dir_path); + // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend + const char * backend_path = std::getenv("GGML_BACKEND_PATH"); + if (backend_path) { + ggml_backend_load(backend_path); + } } diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index fdb4b986f..dba7be33b 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -764,7 +764,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op - if (src_backend_id == sched->n_backends - 1) { + if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { for (int b = 0; b < src_backend_id; b++) { if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); @@ -795,9 +795,12 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str for (int i = 0; i < graph->n_nodes; i++) { if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) { ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id]; - GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), + GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs); for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) { + if (j == 0) { + GGML_LOG_DEBUG(": "); + } GGML_LOG_DEBUG("[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j]))); } diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index f0aecac1b..6b3641c42 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -215,8 +215,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_DEFINITIONS GGML_SSE42) endif() if (GGML_AVX_VNNI) - # MSVC generates AVX512 with AVX-VNNI intrinsics even with /arch:AVX2 - #list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI) + list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI) endif() else () if (GGML_NATIVE) diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index 2d79b8b61..b311a5b1c 
100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -194,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { } static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) const __m256i zero = _mm256_setzero_si256(); return _mm256_dpbusd_epi32(zero, ax, sy); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbusd_avx_epi32(zero, ax, sy); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); @@ -4166,6 +4169,8 @@ static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(g buffer->buft = buft; buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor; buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; return buffer; } diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 634c5fa11..8e1472266 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { } static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); return _mm256_cvtepi32_ps(summed_pairs); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 00f7f1170..c22a66287 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -54,6 +54,7 @@ #include "ggml-quants.h" #include +#include #ifdef _MSC_VER #define NOINLINE __declspec(noinline) @@ -1000,8 +1001,10 @@ class tinyBLAS_Q0_AVX { inline __m256 updot(__m256i u, __m256i s) { __m256i res; -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); +#elif defined(__AVXVNNI__) + res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s); #else res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #endif @@ -1049,6 +1052,704 @@ class tinyBLAS_Q0_AVX { } \ } \ +template +class tinyBLAS_Q0_PPC { + public: + tinyBLAS_Q0_PPC(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + + template + inline void save_res(int ii, int jj, int idx, vector float* fin_res) { + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); + } + } + } + + template + inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array& comparray, vector 
float* vs, vector float* fin_res) { + vector signed int vec_C[4]; + vector float CA[4] = {0}; + vector float res[4] = {0}; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int i = 0; i < 4; i++) { + CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0)); + res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); + fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); + } + } + + template + void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + int64_t i, j; + TA *aoffset = NULL; + VA *vecOffset = NULL; + TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; + VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0}; + VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0}; + VB t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + + aoffset = const_cast(a); + vecOffset = vec; + j = (rows >> 3); + if (j > 0) { + do { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset += 8 * lda; + + i = (cols >> 3); + if (i > 0) { + do { + C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); + C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); + C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); + C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); + C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs); + C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs); + C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs); + C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs); + + __builtin_vsx_disassemble_pair(c1, &C1); + __builtin_vsx_disassemble_pair(c2, &C2); + __builtin_vsx_disassemble_pair(c3, &C3); + __builtin_vsx_disassemble_pair(c4, &C4); + __builtin_vsx_disassemble_pair(c5, &C5); + __builtin_vsx_disassemble_pair(c6, &C6); + __builtin_vsx_disassemble_pair(c7, &C7); + __builtin_vsx_disassemble_pair(c8, &C8); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = 
vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + t1 = vec_perm(c5[0], c6[0], swiz1); + t2 = vec_perm(c5[0], c6[0], swiz2); + t3 = vec_perm(c7[0], c8[0], swiz1); + t4 = vec_perm(c7[0], c8[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+128); + vec_xst(t6, 0, vecOffset+144); + vec_xst(t7, 0, vecOffset+160); + vec_xst(t8, 0, vecOffset+176); + + t1 = vec_perm(c5[1], c6[1], swiz1); + t2 = vec_perm(c5[1], c6[1], swiz2); + t3 = vec_perm(c7[1], c8[1], swiz1); + t4 = vec_perm(c7[1], c8[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+192); + vec_xst(t6, 0, vecOffset+208); + vec_xst(t7, 0, vecOffset+224); + vec_xst(t8, 0, vecOffset+240); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + aoffset4 += lda; + aoffset5 += lda; + aoffset6 += lda; + aoffset7 += lda; + aoffset8 += lda; + vecOffset += 256; + i--; + } while(i > 0); + } + j--; + } while(j > 0); + } + + if (rows & 4) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset += 4 * lda; + + i = (cols >> 3); + if (i > 0) { + do { + C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); + C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); + C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); + C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); + + __builtin_vsx_disassemble_pair(c1, &C1); + __builtin_vsx_disassemble_pair(c2, &C2); + __builtin_vsx_disassemble_pair(c3, &C3); + __builtin_vsx_disassemble_pair(c4, &C4); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + aoffset4 += lda; + vecOffset += 128; + i--; + } while(i > 0); + } + } + if (rows & 3) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + i = (cols >> 3); + if (i > 0) { + do { + switch(rows) { + case 
3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); + __builtin_vsx_disassemble_pair(c3, &C3); + case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); + __builtin_vsx_disassemble_pair(c2, &C2); + case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); + __builtin_vsx_disassemble_pair(c1, &C1); + break; + } + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + vecOffset += 128; + i--; + } while(i > 0); + } + } + } + + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + int m_rem = MIN(m - m0, 8); + int n_rem = MIN(n - n0, 8); + // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance + // issues. After resolving them, below code will be enabled. 
+ /*if (m_rem >= 16 && n_rem >= 8) { + mc = 16; + nc = 8; + gemm<16,8>(m0, m, n0, n); + } else if(m_rem >= 8 && n_rem >= 16) { + mc = 8; + nc = 16; + gemm<8,16>(m0, m, n0, n); + }*/ + if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 8) { + mc = 4; + nc = 8; + gemm<4,8>(m0, m, n0, n); + } else if (m_rem >= 8 && n_rem >= 4) { + mc = 8; + nc = 4; + gemm<8,4>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 4) { + mc = 4; + nc = 4; + gemm_small<4, 4>(m0, m, n0, n); + } else if ((m_rem < 4) && (n_rem > 4)) { + nc = 4; + switch(m_rem) { + case 1: + mc = 1; + gemm_small<1, 4>(m0, m, n0, n); + break; + case 2: + mc = 2; + gemm_small<2, 4>(m0, m, n0, n); + break; + case 3: + mc = 3; + gemm_small<3, 4>(m0, m, n0, n); + break; + default: + return; + } + } else if ((m_rem > 4) && (n_rem < 4)) { + mc = 4; + switch(n_rem) { + case 1: + nc = 1; + gemm_small<4, 1>(m0, m, n0, n); + break; + case 2: + nc = 2; + gemm_small<4, 2>(m0, m, n0, n); + break; + case 3: + nc = 3; + gemm_small<4, 3>(m0, m, n0, n); + break; + default: + return; + } + } else { + switch((m_rem << 4) | n_rem) { + case 0x43: + mc = 4; + nc = 3; + gemm_small<4, 3>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; + gemm_small<4, 2>(m0, m, n0, n); + break; + case 0x41: + mc = 4; + nc = 1; + gemm_small<4, 1>(m0, m, n0, n); + break; + case 0x34: + mc = 3; + nc = 4; + gemm_small<3, 4>(m0, m, n0, n); + break; + case 0x33: + mc = 3; + nc = 3; + gemm_small<3, 3>(m0, m, n0, n); + break; + case 0x32: + mc = 3; + nc = 2; + gemm_small<3, 2>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm_small<3, 1>(m0, m, n0, n); + break; + case 0x24: + mc = 2; + nc = 4; + gemm_small<2, 4>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm_small<2, 3>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm_small<2, 2>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm_small<2, 1>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; + gemm_small<1, 4>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm_small<1, 3>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm_small<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm_small<1, 1>(m0, m, n0, n); + break; + default: + return; + } + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + void KERNEL_4x8(int64_t ii, int64_t jj) { + vec_t vec_A[8], vec_B[16] = {0}; + acc_t acc_0, acc_1; + std::array comparray; + vector float fin_res[8] = {0}; + vector float vs[8] = {0}; + for (int l = 0; l < k; l++) { + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + packNormal((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + for(int x = 0; x < 8; x++) { + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]); + } + for (int I = 0; I<4; I++) { + for (int J = 0; J<4; J++) { + *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + } + } + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 4; i++) { + comparray[i] = 0; + int ca = 0; + const int8_t *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } + compute<4>(&acc_0, 0, 0, comparray, vs, fin_res); + compute<4>(&acc_1, 0, 4, 
comparray, vs, fin_res); + } + save_res<4, 4>(ii, jj, 0, fin_res); + save_res<4, 4>(ii, jj+4, 4, fin_res); + } + + void KERNEL_8x4(int64_t ii, int64_t jj) { + vec_t vec_A[16], vec_B[8] = {0}; + acc_t acc_0, acc_1; + std::array comparray; + vector float fin_res[8] = {0}; + vector float vs[8] = {0}; + for (int l = 0; l < k; l++) { + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); + for(int x = 0; x < 8; x++) { + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); + } + for (int I = 0; I<8; I++) { + for (int J = 0; J<4; J++) { + *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + } + } + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 8; i++) { + comparray[i] = 0; + int ca = 0; + const int8_t *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } + compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); + compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); + } + save_res<4, 4>(ii, jj, 0, fin_res); + save_res<4, 4>(ii+4, jj, 4, fin_res); + } + + void KERNEL_8x8(int64_t ii, int64_t jj) { + vec_t vec_A[16], vec_B[16] = {0}; + acc_t acc_0, acc_1, acc_2, acc_3; + std::array comparray; + vector float fin_res[16] = {0}; + vector float vs[16] = {0}; + for (int l = 0; l < k; l++) { + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(&acc_2); + __builtin_mma_xxsetaccz(&acc_3); + packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + for(int x = 0; x < 8; x++) { + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]); + } + for (int I = 0; I<8; I++) { + for (int J = 0; J<4; J++) { + *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + } + } + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 8; i++) { + comparray[i] = 0; + int ca = 0; + const int8_t *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } + compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); + compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); + compute<8>(&acc_2, 0, 8, comparray, vs, fin_res); + compute<8>(&acc_3, 4, 12, comparray, vs, fin_res); + } + save_res<4, 4>(ii, jj, 0, fin_res); + save_res<4, 4>(ii+4, jj, 4, fin_res); + save_res<4, 4>(ii, jj+4, 8, fin_res); + save_res<4, 4>(ii+4, jj+4, 12, fin_res); + } + + template + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + vec_t vec_A[8], vec_B[8] = {0}; + vector signed int vec_C[4]; + acc_t acc_0; + + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + std::array comparray; + vector float res[4] = {0}; + vector float fin_res[4] = {0}; + vector float vs[4] = {0}; + vector float CA[4] = {0}; + 
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value + for (int l = 0; l < k; l++) { + __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead + __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead + __builtin_mma_xxsetaccz(&acc_0); + packNormal((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); + for(int x = 0; x < 8; x+=4) { + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]); + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]); + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]); + } + for (int I = 0; Id) * unhalf((B+((jj+J)*ldb)+l)->d)); + } + } + __builtin_mma_disassemble_acc(vec_C, &acc_0); + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < RM; i++) { + comparray[i] = 0; + int ca = 0; + const int8_t *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } + + for (int i = 0; i < RM; i++) { + CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0)); + res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); + fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]); + } + } + save_res(ii, jj, 0, fin_res); + } + } + + template + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else { + static_assert(false, "RN/RM values not supported"); + } + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + kernel(ii, jj); + } + } + + const TA *const A; + const TB *const B; + TC *C; + TA *At; + TB *Bt; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; + template class tinyBLAS_PPC { public: @@ -1068,13 +1769,17 @@ class tinyBLAS_PPC { void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); - void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) { + template + void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) { int64_t i, j; - float *aoffset = NULL, *boffset = NULL; - float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - - aoffset = const_cast(a); + TA *aoffset = NULL, *boffset = NULL; + TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; + VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; + VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; + VA t1, t2, t3, t4, t5, t6, t7, t8; + aoffset = const_cast(a); boffset = vec; j = (rows >> 3); if (j > 0) { @@ -1090,9 +1795,6 @@ class tinyBLAS_PPC { aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], 
c7[2], c8[2]; - vector float t1, t2, t3, t4, t5, t6, t7, t8; do { C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); @@ -1172,21 +1874,19 @@ class tinyBLAS_PPC { } while(i > 0); } if (cols & 4) { - vector float c1, c2, c3, c4, c5, c6, c7, c8; - vector float t1, t2, t3, t4, t5, t6, t7, t8; - c1 = vec_xl(0, aoffset1); - c2 = vec_xl(0, aoffset2); - c3 = vec_xl(0, aoffset3); - c4 = vec_xl(0, aoffset4); - c5 = vec_xl(0, aoffset5); - c6 = vec_xl(0, aoffset6); - c7 = vec_xl(0, aoffset7); - c8 = vec_xl(0, aoffset8); + c1[0] = vec_xl(0, aoffset1); + c2[0] = vec_xl(0, aoffset2); + c3[0] = vec_xl(0, aoffset3); + c4[0] = vec_xl(0, aoffset4); + c5[0] = vec_xl(0, aoffset5); + c6[0] = vec_xl(0, aoffset6); + c7[0] = vec_xl(0, aoffset7); + c8[0] = vec_xl(0, aoffset8); - t1 = vec_mergeh(c1, c2); - t2 = vec_mergeh(c3, c4); - t3 = vec_mergeh(c5, c6); - t4 = vec_mergeh(c7, c8); + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); + t3 = vec_mergeh(c5[0], c6[0]); + t4 = vec_mergeh(c7[0], c8[0]); t5 = vec_xxpermdi(t1, t2, 0); t6 = vec_xxpermdi(t3, t4, 0); t7 = vec_xxpermdi(t1, t2, 3); @@ -1196,10 +1896,10 @@ class tinyBLAS_PPC { vec_xst(t7, 0, boffset+8); vec_xst(t8, 0, boffset+12); - t1 = vec_mergel(c1, c2); - t2 = vec_mergel(c3, c4); - t3 = vec_mergel(c5, c6); - t4 = vec_mergel(c7, c8); + t1 = vec_mergel(c1[0], c2[0]); + t2 = vec_mergel(c3[0], c4[0]); + t3 = vec_mergel(c5[0], c6[0]); + t4 = vec_mergel(c7[0], c8[0]); t5 = vec_xxpermdi(t1, t2, 0); t6 = vec_xxpermdi(t3, t4, 0); t7 = vec_xxpermdi(t1, t2, 3); @@ -1221,9 +1921,6 @@ class tinyBLAS_PPC { aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { - __vector_pair C1, C2, C3, C4; - vector float c1[2], c2[2], c3[2], c4[2]; - vector float t1, t2, t3, t4, t5, t6, t7, t8; do { C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); @@ -1270,22 +1967,20 @@ class tinyBLAS_PPC { } if (cols & 4) { - vector float c1, c2, c3, c4; - vector float t1, t2, t3, t4; - c1 = vec_xl(0, aoffset1); - c2 = vec_xl(0, aoffset2); - c3 = vec_xl(0, aoffset3); - c4 = vec_xl(0, aoffset4); + c1[0] = vec_xl(0, aoffset1); + c2[0] = vec_xl(0, aoffset2); + c3[0] = vec_xl(0, aoffset3); + c4[0] = vec_xl(0, aoffset4); - t1 = vec_mergeh(c1, c2); - t2 = vec_mergeh(c3, c4); + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); t3 = vec_xxpermdi(t1, t2, 0); t4 = vec_xxpermdi(t1, t2, 3); vec_xst(t3, 0, boffset); vec_xst(t4, 0, boffset+4); - t1 = vec_mergel(c1, c2); - t2 = vec_mergel(c3, c4); + t1 = vec_mergel(c1[0], c2[0]); + t2 = vec_mergel(c3[0], c4[0]); t3 = vec_xxpermdi(t1, t2, 0); t4 = vec_xxpermdi(t1, t2, 3); vec_xst(t3, 0, boffset+8); @@ -1297,21 +1992,19 @@ class tinyBLAS_PPC { aoffset2 = aoffset1 + lda; aoffset3 = aoffset2 + lda; if (cols & 4) { - vector float c1, c2, c3, c4 = {0}; - vector float t1, t2, t3, t4; - c1 = vec_xl(0, aoffset1); - c2 = vec_xl(0, aoffset2); - c3 = vec_xl(0, aoffset3); + c1[0] = vec_xl(0, aoffset1); + c2[0] = vec_xl(0, aoffset2); + c3[0] = vec_xl(0, aoffset3); - t1 = vec_mergeh(c1, c2); - t2 = vec_mergeh(c3, c4); + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); t3 = vec_xxpermdi(t1, t2, 0); t4 = vec_xxpermdi(t1, t2, 3); vec_xst(t3, 0, boffset); vec_xst(t4, 0, boffset+4); - t1 = vec_mergel(c1, c2); - t2 = vec_mergel(c3, c4); + t1 = vec_mergel(c1[0], c2[0]); + t2 = vec_mergel(c3[0], c4[0]); t3 = vec_xxpermdi(t1, t2, 0); t4 = vec_xxpermdi(t1, t2, 3); vec_xst(t3, 0, boffset+8); @@ -1319,14 +2012,13 @@ class tinyBLAS_PPC { } } 
} - void KERNEL_4x4(int64_t ii, int64_t jj) { vec_t vec_A[4], vec_B[4], vec_C[4]; acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); for (int l = 0; l < k; l+=4) { - READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); - READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); @@ -1341,8 +2033,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); - READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); @@ -1362,8 +2054,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A); - READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); @@ -1385,8 +2077,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); for (int l = 0; l < k; l+=8) { - READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A); - READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B); for(int x = 0; x < 16; x+=2) { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); @@ -1569,15 +2261,15 @@ class tinyBLAS_PPC { vec_t vec_A[4], vec_B[4]; for (int l=0; l= 4 && RM == 1) { - float* a = const_cast(A+(ii)*lda+l); - READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + TA* a = const_cast(A+(ii)*lda+l); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); vec_A[0] = (vec_t)vec_xl(0,a); - vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1)); - vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2)); - vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3)); + vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); } else { - READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); - READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); + packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); } __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); @@ -1587,7 +2279,7 @@ class tinyBLAS_PPC { __builtin_mma_disassemble_acc(vec_C, &acc_0); for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { - *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); + *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J); } } } @@ -1810,6 +2502,20 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; + +#elif 
defined(__MMA__) + if (n < 8 && n != 4) + return false; + if (m < 8 && m != 4) + return false; + tinyBLAS_Q0_PPC tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; + #else return false; #endif diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index 2f42b8a95..aafbaf803 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -124,7 +124,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) uint64_t nb1, uint64_t nb2, uint64_t nb3){ - static_assert(dim >= 0 && dim <= 3); + static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]"); const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 3896f956d..5b0dfacef 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_unary_cuda; default: return nullptr; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c180adc84..0b06be729 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1728,7 +1728,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); - bool use_mul_mat_vec = src0->type == GGML_TYPE_F16 + bool use_mul_mat_vec = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) @@ -2869,6 +2869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_BF16: #ifdef GGML_USE_MUSA if (a->type == GGML_TYPE_Q3_K) { return false; diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index a4b4f6bc1..ac45f2d17 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -1,9 +1,9 @@ #include "common.cuh" #include "mmv.cuh" -template +template static __global__ void mul_mat_vec( - const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, + const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { const int64_t row = blockIdx.x; const int64_t channel = blockIdx.z; @@ -13,7 +13,6 @@ static __global__ void mul_mat_vec( y += channel *stride_channel_y; dst += channel *stride_channel_dst; - const half2 * x2 = (const half2 *) x; const float2 * y2 = (const float2 *) y; extern __shared__ char data_mmv[]; @@ -28,28 +27,44 @@ static __global__ void mul_mat_vec( float sumf; - if (std::is_same::value) { + if constexpr (std::is_same::value) { + const half2 * x2 = (const half2 *) x; + + if (std::is_same::value) { + sumf = 0.0f; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = 
__half22float2(x2[col2]); + const float2 tmpy = y2[col2]; + sumf += tmpx.x * tmpy.x; + sumf += tmpx.y * tmpy.y; + } + } else { +#ifdef FP16_AVAILABLE + half2 sumh2 = make_half2(0.0f, 0.0f); + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmp = y2[col2]; + sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); + } + + sumf = __low2float(sumh2) + __high2float(sumh2); +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE + } + } else if constexpr (std::is_same::value) { + const int * x2 = (const int *) x; sumf = 0.0f; for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmpx = __half22float2(x2[col2]); + const int tmpx = x2[col2]; const float2 tmpy = y2[col2]; - sumf += tmpx.x * tmpy.x; - sumf += tmpx.y * tmpy.y; + sumf += float(reinterpret_cast(&tmpx)[0]) * tmpy.x; + sumf += float(reinterpret_cast(&tmpx)[1]) * tmpy.y; } } else { -#ifdef FP16_AVAILABLE - half2 sumh2 = make_half2(0.0f, 0.0f); - - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmp = y2[col2]; - sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); - } - - sumf = __low2float(sumh2) + __high2float(sumh2); -#else - NO_DEVICE_CODE; -#endif // FP16_AVAILABLE + static_assert(std::is_same::value, "unsupported type"); } sumf = warp_reduce_sum(sumf); @@ -71,9 +86,9 @@ static __global__ void mul_mat_vec( dst[row] = sumf; } -template +template static void launch_mul_mat_vec_cuda( - const half * x, const float * y, float * dst, + const T * x, const float * y, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, cudaStream_t stream) { @@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda( const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 64: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 96: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 128: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 160: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 192: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 224: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 256: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; default: { @@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda( } } +template static void mul_mat_vec_cuda( - const half * x, const float * y, float * dst, + const T * x, const float * y, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t stride_channel_x, const int64_t 
stride_channel_y, const int64_t stride_channel_dst, enum ggml_prec prec, cudaStream_t stream) { switch (prec) { case GGML_PREC_DEFAULT: { - launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, stream); } break; case GGML_PREC_F32: { - launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, stream); } break; } } void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - const half * src0_d = (const half *) src0->data; const float * src1_d = (const float *) src1->data; float * dst_d = (float *) dst->data; @@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type); const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type); - mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + switch (src0->type) { + case GGML_TYPE_F16: { + const half * src0_d = (const half *) src0->data; + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, + channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, + channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } } void ggml_cuda_op_mul_mat_vec( @@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec( const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec( const int64_t channel_stride_y = 0; const int64_t channel_stride_dst = 0; - mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, - nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + switch (src0->type) { + case GGML_TYPE_F16: { + const half * src0_d = (const half *) src0_dd_i; + mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; + mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, channel_stride_x, 
channel_stride_y, channel_stride_dst, prec, stream); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } GGML_UNUSED(ctx); GGML_UNUSED(src1); diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index db9f6a165..1746b0732 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #if CUDART_VERSION < 11020 diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 3205534d6..c905b15d7 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -3,6 +3,7 @@ #include #include #include +#include #ifdef __HIP_PLATFORM_AMD__ // for rocblas_initialize() #include "rocblas/rocblas.h" @@ -121,6 +122,8 @@ #define __has_builtin(x) 0 #endif +typedef hip_bfloat16 nv_bfloat16; + typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); static __device__ __forceinline__ int __vsubss4(const int a, const int b) { diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1604b8229..6cc1b69ee 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #define CUBLAS_COMPUTE_16F CUDA_R_16F #define CUBLAS_COMPUTE_32F CUDA_R_32F @@ -132,3 +133,5 @@ #define cudaKernelNodeParams musaKernelNodeParams #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed #define cudaStreamEndCapture musaStreamEndCapture + +typedef mt_bfloat16 nv_bfloat16; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 549772c57..eab017889 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -3,6 +3,8 @@ // GGML internal header #include "ggml.h" +#include "gguf.h" + #include #include #include // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/ @@ -551,22 +553,15 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) { #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x) #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x) -// expose GGUF internals for test code - -GGML_API size_t gguf_type_size(enum gguf_type type); - -GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); - -struct gguf_buf { - void * data; - size_t size; - size_t offset; -}; -GGML_API struct gguf_buf gguf_buf_init(size_t size); -GGML_API void gguf_buf_free(struct gguf_buf buf); - -GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta); - #ifdef __cplusplus } #endif + +#ifdef __cplusplus +#include + +// expose GGUF internals for test code +GGML_API size_t gguf_type_size(enum gguf_type type); +GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); +GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta); +#endif // __cplusplus diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index 1bad27206..89fcde2fa 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -103,3 +103,19 @@ else() DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib ) endif() # GGML_METAL_EMBED_LIBRARY + +if (NOT GGML_METAL_EMBED_LIBRARY) + install( + FILES src/ggml-metal/ggml-metal.metal + PERMISSIONS + OWNER_READ + OWNER_WRITE + GROUP_READ + 
WORLD_READ + DESTINATION ${CMAKE_INSTALL_BINDIR}) + + install( + FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + DESTINATION ${CMAKE_INSTALL_BINDIR} + ) +endif() diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 28f590f92..a85502ee0 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2067,8 +2067,8 @@ static void ggml_metal_encode_node( GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); - const uint r2 = ne12/ne02; - const uint r3 = ne13/ne03; + const uint32_t r2 = ne12/ne02; + const uint32_t r3 = ne13/ne03; // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index c77d629f0..ed90e471a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2744,13 +2744,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_image_format img_fmt_1d; cl_image_desc img_desc_1d; cl_buffer_region region; - cl_mem A_image1d; - cl_mem B_image1d; - cl_mem B_sub_buffer; - cl_mem C_d; + cl_mem A_image1d = nullptr; + cl_mem B_image1d = nullptr; + cl_mem B_sub_buffer = nullptr; + cl_mem C_d = nullptr; // for B transpose - cl_mem B_d; - cl_mem B_d_input_image; + cl_mem B_d = nullptr; + cl_mem B_d_input_image = nullptr; // <--------------------------------------------> // // define matrix dimensions diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 431082426..63da2b86b 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -27,15 +27,6 @@ #endif #include -#define UNUSED GGML_UNUSED - -#define GGML_DEBUG 0 -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - #ifdef _WIN32 typedef SOCKET sockfd_t; using ssize_t = __int64; @@ -93,9 +84,23 @@ enum rpc_cmd { RPC_CMD_COPY_TENSOR, RPC_CMD_GRAPH_COMPUTE, RPC_CMD_GET_DEVICE_MEMORY, + RPC_CMD_INIT_TENSOR, + RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_COUNT, }; +struct rpc_msg_get_alloc_size_req { + rpc_tensor tensor; +}; + +struct rpc_msg_get_alloc_size_rsp { + uint64_t alloc_size; +}; + +struct rpc_msg_init_tensor_req { + rpc_tensor tensor; +}; + struct rpc_msg_alloc_buffer_req { uint64_t size; }; @@ -397,7 +402,7 @@ static std::shared_ptr get_socket(const std::string & endpoint) { initialized = true; } #else - UNUSED(initialized); + GGML_UNUSED(initialized); #endif auto sock = socket_connect(host.c_str(), port); if (sock == nullptr) { @@ -461,10 +466,18 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { } static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - UNUSED(buffer); - if (ggml_is_quantized(tensor->type)) { - // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized - GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor"); + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + + // CUDA backend on the server pads everything to 512 due to CUDA limitations. + // Due to bandwidth constraints, we only call the server init tensor functions if necessary. 
+ // In particular, only quantized tensors need padding + if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + rpc_msg_init_tensor_req request; + + request.tensor = serialize_tensor(tensor); + + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0); + GGML_ASSERT(status); } } @@ -577,8 +590,23 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { } static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - UNUSED(buft); - return ggml_nbytes(tensor); + // See comments in init_tensor. + if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + auto sock = get_socket(buft_ctx->endpoint); + + rpc_msg_get_alloc_size_req request; + + request.tensor = serialize_tensor(tensor); + + rpc_msg_get_alloc_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response)); + GGML_ASSERT(status); + + return response.alloc_size; + } else { + return ggml_nbytes(tensor); + } } static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { @@ -603,7 +631,7 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) { } static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { - UNUSED(backend); + GGML_UNUSED(backend); // this is no-op because we don't have any async operations } @@ -757,6 +785,8 @@ public: bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); + bool init_tensor(const rpc_msg_init_tensor_req & request); + bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); private: ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); @@ -770,6 +800,36 @@ private: std::unordered_set buffers; }; +bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { + ggml_backend_buffer_type_t buft; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + + if (tensor == nullptr) { + GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); + ggml_free(ctx); + return false; + } + + if (tensor->buffer == nullptr) { + //No buffer allocated. 
+ buft = ggml_backend_get_default_buffer_type(backend); + } else { + buft = tensor->buffer->buft; + } + + response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor); + + ggml_free(ctx); + return true; +} + void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size); @@ -781,7 +841,7 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size); buffers.insert(buffer); } else { - GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); + GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); } } @@ -803,7 +863,7 @@ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rp GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); if (buffers.find(buffer) == buffers.end()) { - GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); return false; } void * base = ggml_backend_buffer_get_base(buffer); @@ -815,7 +875,7 @@ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) { GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); if (buffers.find(buffer) == buffers.end()) { - GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); return false; } ggml_backend_buffer_free(buffer); @@ -827,7 +887,7 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) { GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value); ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); if (buffers.find(buffer) == buffers.end()) { - GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); return false; } ggml_backend_buffer_clear(buffer, request.value); @@ -883,7 +943,7 @@ bool rpc_server::set_tensor(const std::vector & input) { struct ggml_context * ctx = ggml_init(params); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { - GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__); + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); ggml_free(ctx); return false; } @@ -905,6 +965,40 @@ bool rpc_server::set_tensor(const std::vector & input) { return true; } +bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n"); + ggml_free(ctx); + return false; + } + + // Call the backend's buffer_init_tensor function + ggml_backend_buffer_t buffer = tensor->buffer; + if (buffer && buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } else { + GGML_LOG_ERROR("Null buffer for tensor 
passed to init_tensor function\n"); + } + + if (tensor->extra != nullptr) { + // This pointer can either be passed around client/server, or probably better stored server-side and kept track of. + // Currently unimplemented. + GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n"); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + return true; +} + bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response) { struct ggml_init_params params { /*.mem_size =*/ ggml_tensor_overhead(), @@ -914,7 +1008,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< struct ggml_context * ctx = ggml_init(params); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { - GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__); + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); ggml_free(ctx); return false; } @@ -948,7 +1042,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co ggml_tensor * src = deserialize_tensor(ctx, &request.src); ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); if (src == nullptr || dst == nullptr) { - GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__); + GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); ggml_free(ctx); return false; } @@ -1058,6 +1152,18 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre } break; } + case RPC_CMD_GET_ALLOC_SIZE: { + rpc_msg_get_alloc_size_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_get_alloc_size_rsp response; + server.get_alloc_size(request, response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } case RPC_CMD_GET_ALIGNMENT: { if (!recv_msg(sockfd, nullptr, 0)) { return; @@ -1133,6 +1239,19 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre } break; } + case RPC_CMD_INIT_TENSOR: { + rpc_msg_init_tensor_req request; + if (!recv_msg(sockfd, &request,sizeof(request))) { + return; + } + if (!server.init_tensor(request)) { + return; + } + if (!send_msg(sockfd, nullptr, 0)) { + return; + } + break; + } case RPC_CMD_GET_TENSOR: { rpc_msg_get_tensor_req request; if (!recv_msg(sockfd, &request, sizeof(request))) { @@ -1257,14 +1376,14 @@ static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total); - UNUSED(dev); + GGML_UNUSED(dev); } static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) { // TODO: obtain value from the server return GGML_BACKEND_DEVICE_TYPE_GPU; - UNUSED(dev); + GGML_UNUSED(dev); } static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { @@ -1285,7 +1404,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const return ggml_backend_rpc_init(ctx->endpoint.c_str()); - UNUSED(params); + GGML_UNUSED(params); } static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -1293,12 +1412,12 @@ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_b return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); - UNUSED(dev); + GGML_UNUSED(dev); } static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - UNUSED(dev); - UNUSED(op); + GGML_UNUSED(dev); + 
GGML_UNUSED(op); //TODO: call the remote backend and cache the results return true; } @@ -1335,20 +1454,20 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { return "RPC"; - UNUSED(reg); + GGML_UNUSED(reg); } static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) { return 0; - UNUSED(reg); + GGML_UNUSED(reg); } static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead"); - UNUSED(reg); - UNUSED(index); + GGML_UNUSED(reg); + GGML_UNUSED(index); } static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) { @@ -1357,7 +1476,7 @@ static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const ch } return NULL; - UNUSED(reg); + GGML_UNUSED(reg); } static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = { diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp index 75ddfb86a..105db6f03 100644 --- a/ggml/src/ggml-sycl/wkv6.cpp +++ b/ggml/src/ggml-sycl/wkv6.cpp @@ -131,7 +131,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s [=](sycl::nd_item<3> item_ct1) { rwkv_wkv_f32_kernel( B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, - item_ct1, shared_mem_acc.get_pointer() + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() ); }); }); diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 6d46e5f24..c0ddaac82 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,6 +8,20 @@ if (Vulkan_FOUND) ../../include/ggml-vulkan.h ) + # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. + # If it's not, there will be an error to stderr. + # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") + message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") + else() + message(STATUS "GL_KHR_cooperative_matrix supported by glslc") + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + endif() + # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. # If it's not, there will be an error to stderr. 
# If it's supported, set a define to indicate that we should compile those shaders @@ -69,6 +83,10 @@ if (Vulkan_FOUND) file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") + if (NOT CMAKE_CROSSCOMPILING) + set(_ggml_vk_genshaders_cmd "$/${_ggml_vk_genshaders_cmd}") + endif () + add_custom_command( OUTPUT ${_ggml_vk_header} ${_ggml_vk_source} diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8e47e79ae..077452424 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -145,6 +145,8 @@ class vk_perf_logger; #endif static void ggml_vk_destroy_buffer(vk_buffer& buf); +static constexpr uint32_t mul_mat_vec_max_cols = 8; + struct vk_device_struct { std::mutex mutex; @@ -202,8 +204,8 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; - vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT]; - vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; @@ -1643,6 +1645,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #undef CREATE_MM2 } else #endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat_support) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -1737,7 +1740,9 @@ static void ggml_vk_load_shaders(vk_device& device) { } #undef CREATE_MM2 #undef CREATE_MM - } else if (device->fp16) { + } else +#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->fp16) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l) \ @@ -1866,33 +1871,35 @@ static void ggml_vk_load_shaders(vk_device& device) { } else if (device->vendor_id == VK_VENDOR_ID_INTEL) rm_stdq = 2; - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); + for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 
2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], 
"mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, 
mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -2036,6 +2043,8 @@ static void ggml_vk_load_shaders(vk_device& device) { std::cerr << "Done!" << std::endl; } +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); + static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -2171,9 +2180,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; - if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) { - // Intel drivers don't support coopmat properly yet - // Only RADV supports coopmat properly on AMD + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { device->coopmat_support = false; } @@ -2238,6 +2245,7 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; } +#if defined(VK_KHR_cooperative_matrix) VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; coopmat_features.pNext = nullptr; coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; @@ -2247,6 +2255,7 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; last_struct = (VkBaseOutStructure *)&coopmat_features; } +#endif #if defined(VK_NV_cooperative_matrix2) VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; @@ -2279,7 +2288,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_EXT_subgroup_size_control"); } +#if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; +#endif if (coopmat2_support) { #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -2372,6 +2383,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_float16_int8"); } +#if defined(VK_KHR_cooperative_matrix) if (device->coopmat_support) { // Query supported shapes std::vector cm_props; @@ -2438,7 +2450,7 @@ static vk_device ggml_vk_get_device(size_t idx) { if (device->coopmat_support) { device_extensions.push_back("VK_KHR_cooperative_matrix"); } - +#endif device->name = GGML_VK_NAME + std::to_string(idx); device_create_info = { @@ -2511,7 +2523,6 @@ static vk_device 
ggml_vk_get_device(size_t idx) { return vk_instance.devices[idx]; } - static void ggml_vk_print_gpu_info(size_t idx) { GGML_ASSERT(idx < vk_instance.device_indices.size()); size_t dev_num = vk_instance.device_indices[idx]; @@ -2550,9 +2561,11 @@ static void ggml_vk_print_gpu_info(size_t idx) { fp16_storage = true; } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { fp16_compute = true; - } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT")) { coopmat_support = true; +#endif #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { @@ -2561,9 +2574,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { } } - if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) { - // Intel drivers don't support coopmat properly yet - // Only RADV supports coopmat properly on AMD + if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { coopmat_support = false; } @@ -2592,6 +2603,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { // Pointer to the last chain element VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; coopmat_features.pNext = nullptr; coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; @@ -2607,6 +2619,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { fp16 = fp16 && vk12_features.shaderFloat16; coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; +#endif std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; @@ -2892,9 +2905,10 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; } -static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); + GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); switch (a_type) { case GGML_TYPE_F32: @@ -2915,7 +2929,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * return nullptr; } - return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type]; + return b_type == GGML_TYPE_F32 ? 
ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1]; } static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { @@ -3925,8 +3939,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t ne12 = src1->ne[2]; const uint64_t ne13 = src1->ne[3]; - GGML_ASSERT(ne11 == 1); - const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; const uint64_t ne22 = dst->ne[2]; @@ -3935,6 +3947,11 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t r2 = ne12 / ne02; const uint64_t r3 = ne13 / ne03; + // batch_n indicates that we need to compute a few vector results, and this assumes + // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. + GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); + bool batch_n = ne11 > 1; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; @@ -3985,7 +4002,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type); + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11); GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); @@ -4057,8 +4074,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); } - uint32_t stride_batch_x = ne00*ne01; - uint32_t stride_batch_y = ne10*ne11; + // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride + uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; + uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); + uint32_t stride_batch_d = batch_n ? 
ne20 : (ne20*ne21); if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); @@ -4081,7 +4100,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& // compute const vk_mat_vec_push_constants pc = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21), + stride_batch_x, stride_batch_y, stride_batch_d, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, }; ggml_vk_sync_buffers(subctx); @@ -4261,7 +4280,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); - } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) + // when ne12 and ne13 are one. + } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); } else { ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); @@ -8075,6 +8097,25 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve UNUSED(instance_extensions); } +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { + switch (props.vendorID) { + case VK_VENDOR_ID_INTEL: + // Intel drivers don't support coopmat properly yet + return false; + case VK_VENDOR_ID_AMD: + if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { + // Workaround for AMD proprietary driver reporting support on all GPUs + const std::string name = props.deviceName; + return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs + name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs + name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs + } + return true; + default: + return true; + } +} + // checks #ifdef GGML_VULKAN_CHECK_RESULTS diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 187c31916..24875cdcf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -9,9 +9,6 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (constant_id = 1) const uint NUM_ROWS = 1; - #if !defined(DATA_A_F32) && !defined(DATA_A_F16) #define K_PER_ITER 8 #else @@ -21,70 +18,70 @@ layout (constant_id = 1) const uint NUM_ROWS = 1; uint a_offset, b_offset, d_offset, y_offset; -shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; - -void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, 
const uint tid, const uint i, bool lastiter) { - const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; - const uint iqs = (col%QUANT_K)/QUANT_R; // quant index - const uint iybs = col - col%QUANT_K; // y block start index + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = (col%QUANT_K)/QUANT_R; // quant index + const uint iybs = col - col%QUANT_K; // y block start index #if K_PER_ITER == 8 #if QUANT_R == 2 - const B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4]; - const B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4]; - const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); - const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); + const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; + const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]; + const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); + const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); #else - const vec4 bv0 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4]); - const vec4 bv1 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4 + 1]); + const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); #endif #else - // Check if the second of the pair of elements is OOB, and don't fetch B or - // accumulate it. We still fetch a pair of elements for A, which is fine for - // quantized formats since they'll be within the same block. We should - // probably skip fetching the second element for F16/F32, but as of now we - // still do. - const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); + // Check if the second of the pair of elements is OOB, and don't fetch B or + // accumulate it. We still fetch a pair of elements for A, which is fine for + // quantized formats since they'll be within the same block. We should + // probably skip fetching the second element for F16/F32, but as of now we + // still do. 
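
For orientation on the reshaped shader loop above: with the new NUM_COLS dimension, each invocation accumulates one partial dot product per (column j, row n) pair, reading the j-th input vector at offset j*p.batch_stride_b and writing its results at offset j*p.batch_stride_d. The following scalar C++ reference shows the end result with quantization, offsets, and the workgroup reduction stripped out; the names and the row-major layout are assumptions made for illustration, not the ggml test harness.

    // Scalar reference for the multi-column matrix-vector product: num_cols input
    // vectors against the same nrows x ncols matrix, producing num_cols result
    // vectors. Strides mirror the batch_stride_b / batch_stride_d push constants.
    #include <cstdint>
    #include <vector>

    void mul_mat_vec_ref(const std::vector<float> & A,   // nrows x ncols, row-major
                         const std::vector<float> & B,   // input vector j starts at j * stride_b
                         std::vector<float> & D,         // result vector j starts at j * stride_d
                         uint32_t nrows, uint32_t ncols, uint32_t num_cols,
                         uint32_t stride_b, uint32_t stride_d) {
        for (uint32_t j = 0; j < num_cols; ++j) {
            for (uint32_t r = 0; r < nrows; ++r) {
                float acc = 0.0f;
                for (uint32_t c = 0; c < ncols; ++c) {
                    acc += A[r*ncols + c] * B[j*stride_b + c];
                }
                D[j*stride_d + r] = acc;
            }
        }
    }

On the host side this corresponds to selecting pipeline_dequant_mul_mat_vec_*[src0_type][ne11-1] and, in the batch_n case, overloading the batch strides with the row strides of B and D, as the comments added to ggml_vk_mul_mat_vec_q_f16 describe.
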
+ const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); - FLOAT_TYPE b0 = 0, b1 = 0; - b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]); - if (!OOB) { - b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]); - } + FLOAT_TYPE b0 = 0, b1 = 0; + b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); + if (!OOB) { + b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); + } #endif - uint ibi = first_row*p.ncols; - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib = (ibi + col)/QUANT_K; // block index - ibi += p.ncols; + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; #if K_PER_ITER == 8 - vec4 v = dequantize4(ib, iqs, a_offset); - vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); + vec4 v = dequantize4(ib, iqs, a_offset); + vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); - const vec2 dm = get_dm(ib, a_offset); - if (dm.y != 0) { // quant has min component - v = v * dm.x + dm.y; - v2 = v2 * dm.x + dm.y; - } + const vec2 dm = get_dm(ib, a_offset); + if (dm.y != 0) { // quant has min component + v = v * dm.x + dm.y; + v2 = v2 * dm.x + dm.y; + } - // matrix multiplication - FLOAT_TYPE rowtmp = dot(bv0, v); - rowtmp += dot(bv1, v2); + // matrix multiplication + FLOAT_TYPE rowtmp = dot(bv0, v); + rowtmp += dot(bv1, v2); - if (dm.y == 0) - rowtmp *= dm.x; + if (dm.y == 0) + rowtmp *= dm.x; - temp[n] += rowtmp; + temp[j][n] += rowtmp; #else - const vec2 v = dequantize(ib, iqs, a_offset); + const vec2 v = dequantize(ib, iqs, a_offset); - // matrix multiplication - temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]); - if (!OOB) { - temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]); - } + // matrix multiplication + temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); + if (!OOB) { + temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + } #endif + } } } @@ -96,10 +93,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { y_offset = QUANT_R == 1 ? 
1 : QUANT_K/2; - FLOAT_TYPE temp[NUM_ROWS]; + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - for (uint i = 0; i < NUM_ROWS; ++i) { - temp[i] = FLOAT_TYPE(0); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } } uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); @@ -131,24 +130,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } - // sum up partial sums and write back result - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] = temp[n]; - } - barrier(); - [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { - if (tid < s) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] += tmpsh[n][tid + s]; - } - } - barrier(); - } - if (tid == 0) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]); - } - } + reduce_result(temp, d_offset, first_row, num_rows, tid); } void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp index 3894fca82..903753c7e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -83,3 +83,36 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { batch_idx * p.batch_stride_d; #endif } + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (constant_id = 1) const uint NUM_ROWS = 1; +layout (constant_id = 2) const uint NUM_COLS = 1; + +shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; + +void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + // sum up partial sums and write back result + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] = temp[j][n]; + } + } + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; + } + } + } + barrier(); + } + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 138ad0184..934213446 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -5,11 +5,6 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (constant_id = 1) const uint NUM_ROWS = 1; - -shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; - void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -32,24 +27,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint s_offset = 8*v_im; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_ROWS]; + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { - temp[i] = FLOAT_TYPE(0); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; 
i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } } [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y_idx = i * QUANT_K + y_offset; - B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0]; - B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8]; - B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16]; - B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24]; - B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32]; - B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40]; - B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48]; - B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56]; - [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; f16vec2 d = data_a[ib0 + i].d; @@ -74,48 +62,42 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uvec2 qs0 = uvec2(unpack8(qs0_u16)); uvec2 qs16 = uvec2(unpack8(qs16_u16)); - FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); - FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; + B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; + B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; + B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; + B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; + B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; + B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; + B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), + fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), + fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), + fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), + fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), + fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), + fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] 
>> 6) & 3), + fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), + fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), + fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), + fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), + fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), + fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), + fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), + fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); } - temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n])); } } - // sum up partial sums and write back result - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] = temp[n]; - } - barrier(); - [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { - if (tid < s) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] += tmpsh[n][tid + s]; - } - } - barrier(); - } - if (tid == 0) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]); - } - } + reduce_result(temp, d_offset, first_row, num_rows, tid); } void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 82ec42d25..86b0159d9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -5,11 +5,6 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (constant_id = 1) const uint NUM_ROWS = 1; - -shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; - void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -33,10 +28,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_ROWS]; + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { - temp[i] = FLOAT_TYPE(0); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } } const uint s_shift = 4 * v_im; @@ -44,15 +41,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y_idx = i * QUANT_K + y_offset; - B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0]; - B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8]; - B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16]; - B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24]; - B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32]; - B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40]; - B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48]; - B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56]; - [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); @@ -70,39 +58,34 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { u8vec2 s8 = unpack8(s8_16); u8vec2 s10 = unpack8(s10_16); - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) 
& 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; + B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; + B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; + B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; + B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; + B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; + B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; + B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 
0 : 4)), + fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); } - temp[n] = fma(d, sum, temp[n]); } } - // sum up partial sums and write back result - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] = temp[n]; - } - barrier(); - [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { - if (tid < s) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] += tmpsh[n][tid + s]; - } - } - barrier(); - } - if (tid == 0) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]); - } - } + reduce_result(temp, d_offset, first_row, num_rows, tid); } void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 677c207a8..cd1dd8e89 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -6,11 +6,6 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (constant_id = 1) const uint NUM_ROWS = 1; - -shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; - void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -36,21 +31,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_ROWS]; + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { - temp[i] = FLOAT_TYPE(0); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } } [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y1_idx = i * QUANT_K + y_offset; const uint y2_idx = y1_idx + 128; - B_TYPE_VEC4 by10 = data_b_v4[(b_offset + y1_idx) / 4]; - B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8]; - B_TYPE_VEC4 by20 = data_b_v4[(b_offset + y2_idx) / 4]; - B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8]; - [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; f16vec2 d = data_a[ib0 + i].d; 
@@ -103,37 +95,27 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint32_t q4_14 = qs64_hi4.z; const uint32_t q4_15 = qs64_hi4.w; - const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); - const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, - fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, - fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, - fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); - temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n])); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4]; + B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]; + B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4]; + B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]; + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } } } - // sum up partial sums and write back result - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] = temp[n]; - } - barrier(); - [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { - if (tid < s) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] += tmpsh[n][tid + s]; - } - } - barrier(); - } - if (tid == 0) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]); - } - } + reduce_result(temp, d_offset, first_row, num_rows, tid); } void main() { diff --git 
a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index ed3c25d89..0a68891c3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -6,11 +6,6 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (constant_id = 1) const uint NUM_ROWS = 1; - -shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; - void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -33,25 +28,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_ROWS]; + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { - temp[i] = FLOAT_TYPE(0); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } } [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y1_idx = i * QUANT_K + y_offset; const uint y2_idx = y1_idx + 128; - B_TYPE_VEC2 by10 = data_b_v2[(b_offset + y1_idx) / 2]; - B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8]; - B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16]; - B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24]; - B_TYPE_VEC2 by20 = data_b_v2[(b_offset + y2_idx) / 2]; - B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8]; - B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16]; - B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24]; - [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; f16vec2 d = data_a[ib0 + i].d; @@ -116,53 +104,47 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint32_t q4_14 = qs64_80_hi4.z; const uint32_t q4_15 = qs64_80_hi4.w; - const FLOAT_TYPE sx = - fma(FLOAT_TYPE(by10.x), q4_0, - fma(FLOAT_TYPE(by10.y), q4_1, - fma(FLOAT_TYPE(by116.x), q4_2, - FLOAT_TYPE(by116.y) * q4_3))); - const FLOAT_TYPE sy = - fma(FLOAT_TYPE(by132.x), q4_4, - fma(FLOAT_TYPE(by132.y), q4_5, - fma(FLOAT_TYPE(by148.x), q4_6, - FLOAT_TYPE(by148.y) * q4_7))); - const FLOAT_TYPE sz = - fma(FLOAT_TYPE(by20.x), q4_8, - fma(FLOAT_TYPE(by20.y), q4_9, - fma(FLOAT_TYPE(by216.x), q4_10, - FLOAT_TYPE(by216.y) * q4_11))); - const FLOAT_TYPE sw = - fma(FLOAT_TYPE(by232.x), q4_12, - fma(FLOAT_TYPE(by232.y), q4_13, - fma(FLOAT_TYPE(by248.x), q4_14, - FLOAT_TYPE(by248.y) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, - fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, - fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, - (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); - temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n])); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2]; + B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]; + B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]; + B_TYPE_VEC2 by148 = 
data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]; + B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2]; + B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]; + B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]; + B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]; + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } } } - // sum up partial sums and write back result - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] = temp[n]; - } - barrier(); - [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { - if (tid < s) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] += tmpsh[n][tid + s]; - } - } - barrier(); - } - if (tid == 0) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]); - } - } + reduce_result(temp, d_offset, first_row, num_rows, tid); } void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index fab4ff5ff..70e13a56b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -6,11 +6,6 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (constant_id = 1) const uint NUM_ROWS = 1; - -shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; - void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -36,20 +31,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint s_offset = 8*v_im + is; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_ROWS]; + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { - temp[i] = FLOAT_TYPE(0); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } } [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y_idx = i * QUANT_K + y_offset; - B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4]; - B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8]; - B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16]; - B_TYPE_VEC4 by96 = 
data_b_v4[(b_offset + y_idx) / 4 + 24]; - [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); @@ -84,35 +76,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uvec4 q2 = uvec4(unpack8(q2_u32)); uvec4 q3 = uvec4(unpack8(q3_u32)); - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 4; ++l) { - sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), - fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), - fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), - fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4]; + B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]; + B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]; + B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]; + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 4; ++l) { + sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), + fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), + fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), + fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); + } + temp[j][n] += sum * d; } - temp[n] += sum * d; } } - // sum up partial sums and write back result - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] = temp[n]; - } - barrier(); - [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { - if (tid < s) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] += tmpsh[n][tid + s]; - } - } - barrier(); - } - if (tid == 0) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]); - } - } + reduce_result(temp, d_offset, first_row, num_rows, tid); } void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp new file mode 100644 index 000000000..8c5dd1bd1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_KHR_cooperative_matrix : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8111c0638..7b5044798 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -342,9 +342,11 @@ void process_shaders() { matmul_shaders(true, matmul_id, false, false, false); matmul_shaders(true, matmul_id, false, false, true); +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) // Coopmat, fp32acc and fp16acc matmul_shaders(true, matmul_id, true, false, false); matmul_shaders(true, matmul_id, true, false, true); +#endif #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) // Coopmat2, fp32acc and fp16acc diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2bbe5f482..90abc6ad4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1588,15 +1588,8 @@ static struct ggml_tensor * ggml_new_tensor_impl( struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); -#ifdef __clang__ - // temporary until 
ggml_tensor::backend is removed - #pragma clang diagnostic push - #pragma clang diagnostic ignored "-Wdeprecated-declarations" -#endif - *result = (struct ggml_tensor) { /*.type =*/ type, - /*.backend =*/ GGML_BACKEND_TYPE_CPU, /*.buffer =*/ NULL, /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, @@ -1612,10 +1605,6 @@ static struct ggml_tensor * ggml_new_tensor_impl( /*.padding =*/ { 0 }, }; -#ifdef __clang__ - #pragma clang diagnostic pop -#endif - // TODO: this should not be needed as long as we don't rely on aligned SIMD loads //GGML_ASSERT_ALIGNED(result->data); @@ -6417,1271 +6406,6 @@ size_t ggml_quantize_chunk( //////////////////////////////////////////////////////////////////////////////// -struct gguf_str { - uint64_t n; // GGUFv2 - char * data; -}; - -static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = { - [GGUF_TYPE_UINT8] = sizeof(uint8_t), - [GGUF_TYPE_INT8] = sizeof(int8_t), - [GGUF_TYPE_UINT16] = sizeof(uint16_t), - [GGUF_TYPE_INT16] = sizeof(int16_t), - [GGUF_TYPE_UINT32] = sizeof(uint32_t), - [GGUF_TYPE_INT32] = sizeof(int32_t), - [GGUF_TYPE_FLOAT32] = sizeof(float), - [GGUF_TYPE_BOOL] = sizeof(bool), - [GGUF_TYPE_STRING] = sizeof(struct gguf_str), - [GGUF_TYPE_UINT64] = sizeof(uint64_t), - [GGUF_TYPE_INT64] = sizeof(int64_t), - [GGUF_TYPE_FLOAT64] = sizeof(double), - [GGUF_TYPE_ARRAY] = 0, // undefined -}; -static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); - -static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = { - [GGUF_TYPE_UINT8] = "u8", - [GGUF_TYPE_INT8] = "i8", - [GGUF_TYPE_UINT16] = "u16", - [GGUF_TYPE_INT16] = "i16", - [GGUF_TYPE_UINT32] = "u32", - [GGUF_TYPE_INT32] = "i32", - [GGUF_TYPE_FLOAT32] = "f32", - [GGUF_TYPE_BOOL] = "bool", - [GGUF_TYPE_STRING] = "str", - [GGUF_TYPE_ARRAY] = "arr", - [GGUF_TYPE_UINT64] = "u64", - [GGUF_TYPE_INT64] = "i64", - [GGUF_TYPE_FLOAT64] = "f64", -}; -static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); - -union gguf_value { - uint8_t uint8; - int8_t int8; - uint16_t uint16; - int16_t int16; - uint32_t uint32; - int32_t int32; - float float32; - uint64_t uint64; - int64_t int64; - double float64; - bool bool_; - - struct gguf_str str; - - struct { - enum gguf_type type; - - uint64_t n; // GGUFv2 - void * data; - } arr; -}; - -struct gguf_kv { - struct gguf_str key; - - enum gguf_type type; - union gguf_value value; -}; - -struct gguf_header { - char magic[4]; - - uint32_t version; - uint64_t n_tensors; // GGUFv2 - uint64_t n_kv; // GGUFv2 -}; - -struct gguf_tensor_info { - struct gguf_str name; - - uint32_t n_dims; - uint64_t ne[GGML_MAX_DIMS]; - - enum ggml_type type; - - uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` - - // for writing API - const void * data; - size_t size; -}; - -struct gguf_context { - struct gguf_header header; - - struct gguf_kv * kv; - struct gguf_tensor_info * infos; - - size_t alignment; - size_t offset; // offset of `data` from beginning of file - size_t size; // size of `data` in bytes - - //uint8_t * padding; - void * data; -}; - -size_t gguf_type_size(enum gguf_type type) { - GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT); - return GGUF_TYPE_SIZE[type]; -} - -static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) { - if (info->n_dims > GGML_MAX_DIMS) { - fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims); - return false; - } - - if (info->type < 0 || info->type >= GGML_TYPE_COUNT) { - fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type); - return false; - } - - if 
(strlen(info->name.data) >= GGML_MAX_NAME) { - fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data); - return false; - } - - for (uint32_t i = 0; i < info->n_dims; ++i) { - if (info->ne[i] <= 0) { - fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]); - return false; - } - } - - // prevent overflow for total number of elements - if (INT64_MAX/info->ne[1] <= info->ne[0]) { - fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]); - return false; - } - - if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) { - fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]); - return false; - } - - if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) { - fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]); - return false; - } - - return true; -} - -static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) { - const size_t n = fread(dst, 1, size, file); - *offset += n; - return n == size; -} - -static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) { - p->n = 0; - p->data = NULL; - - bool ok = true; - - ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset); - - // early exit if string length is invalid, prevents from integer overflow - if (p->n == SIZE_MAX) { - fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n); - return false; - } - - p->data = calloc(p->n + 1, 1); - if (!p->data) { - fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n); - return false; - } - - ok = ok && gguf_fread_el(file, p->data, p->n, offset); - - return ok; -} - -static void gguf_free_kv(struct gguf_kv * kv) { - if (kv->key.data) { - GGML_FREE(kv->key.data); - } - - if (kv->type == GGUF_TYPE_STRING) { - if (kv->value.str.data) { - GGML_FREE(kv->value.str.data); - } - } - - if (kv->type == GGUF_TYPE_ARRAY) { - if (kv->value.arr.data) { - if (kv->value.arr.type == GGUF_TYPE_STRING) { - for (uint64_t j = 0; j < kv->value.arr.n; ++j) { - struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j]; - if (str->data) { - GGML_FREE(str->data); - } - } - } - GGML_FREE(kv->value.arr.data); - } - } -} - -struct gguf_context * gguf_init_empty(void) { - struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context)); - if (!ctx) { - fprintf(stderr, "%s: failed to allocate memory for context\n", __func__); - return NULL; - } - - memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic)); - ctx->header.version = GGUF_VERSION; - ctx->header.n_tensors = 0; - ctx->header.n_kv = 0; - - ctx->kv = NULL; - ctx->infos = NULL; - - ctx->alignment = GGUF_DEFAULT_ALIGNMENT; - ctx->offset = 0; - ctx->size = 0; - - ctx->data = NULL; - - return ctx; -} - -struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) { - // offset from start of file - size_t offset = 0; - - char magic[4]; - - // check the magic before making allocations - { - gguf_fread_el(file, &magic, sizeof(magic), &offset); - - for (uint32_t i = 0; i < sizeof(magic); i++) { - if (magic[i] != GGUF_MAGIC[i]) { - fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]); - return NULL; - } - } - } - - bool ok = true; - - struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context)); - if (!ctx) { - fprintf(stderr, "%s: failed to allocate memory for context\n", __func__); - return 
NULL; - } - - // read the header - { - strncpy(ctx->header.magic, magic, 4); - - ctx->kv = NULL; - ctx->infos = NULL; - ctx->data = NULL; - - ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset); - ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); - ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); - - if (ctx->header.version == 1) { - fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__); - gguf_free(ctx); - return NULL; - } - - // sanity-checks to prevent from integer/buffer overflows - - ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info)); - ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead()); - ok = ok && (ctx->header.n_kv < (SIZE_MAX/2)/sizeof(struct gguf_kv)); - - if (!ok) { - fprintf(stderr, "%s: failed to read header\n", __func__); - gguf_free(ctx); - return NULL; - } - } - - // read the kv pairs - { - const uint64_t n_kv = ctx->header.n_kv; - - if (n_kv > 0) { - ctx->kv = calloc(n_kv, sizeof(struct gguf_kv)); - if (!ctx->kv) { - fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__); - gguf_free(ctx); - return NULL; - } - } - - for (uint64_t i = 0; i < n_kv; ++i) { - struct gguf_kv * kv = &ctx->kv[i]; - - //fprintf(stderr, "%s: reading kv %d\n", __func__, i); - - ok = ok && gguf_fread_str(file, &kv->key, &offset); - ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset); - - //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data); - - switch (kv->type) { - case GGUF_TYPE_UINT8: ok = ok && gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break; - case GGUF_TYPE_INT8: ok = ok && gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break; - case GGUF_TYPE_UINT16: ok = ok && gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break; - case GGUF_TYPE_INT16: ok = ok && gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break; - case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break; - case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break; - case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break; - case GGUF_TYPE_UINT64: ok = ok && gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break; - case GGUF_TYPE_INT64: ok = ok && gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break; - case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break; - case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break; - case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break; - case GGUF_TYPE_ARRAY: - { - ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset); - ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); - - switch (kv->value.arr.type) { - case GGUF_TYPE_UINT8: - case GGUF_TYPE_INT8: - case GGUF_TYPE_UINT16: - case GGUF_TYPE_INT16: - case GGUF_TYPE_UINT32: - case GGUF_TYPE_INT32: - case GGUF_TYPE_FLOAT32: - case GGUF_TYPE_UINT64: - case GGUF_TYPE_INT64: - case GGUF_TYPE_FLOAT64: - case GGUF_TYPE_BOOL: - { 
- // prevent from integer overflow in the malloc below - if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) { - fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n); - gguf_free(ctx); - return NULL; - } - - kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type)); - if (!kv->value.arr.data) { - fprintf(stderr, "%s: failed to allocate memory for array\n", __func__); - gguf_free(ctx); - return NULL; - } - - ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset); - } break; - case GGUF_TYPE_STRING: - { - // prevent from integer overflow in the malloc below - if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) { - fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n); - gguf_free(ctx); - return NULL; - } - - kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str)); - if (!kv->value.arr.data) { - fprintf(stderr, "%s: failed to allocate memory for array\n", __func__); - gguf_free(ctx); - return NULL; - } - - for (uint64_t j = 0; j < kv->value.arr.n; ++j) { - ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset); - } - } break; - case GGUF_TYPE_ARRAY: - default: - { - fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type); - ok = false; - } break; - } - } break; - default: - { - fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type); - ok = false; - } break; - } - - if (!ok) { - break; - } - } - - if (!ok) { - fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); - gguf_free(ctx); - return NULL; - } - } - - // read the tensor infos - if (ctx->header.n_tensors > 0) { - ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info)); - if (!ctx->infos) { - fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__); - gguf_free(ctx); - return NULL; - } - - for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - info->ne[j] = 1; - } - - ok = ok && gguf_fread_str(file, &info->name, &offset); - ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset); - - ok = ok && (info->n_dims <= GGML_MAX_DIMS); - - for (uint32_t j = 0; j < info->n_dims; ++j) { - ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); - } - - ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset); - ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); - - ok = ok && gguf_tensor_info_sanitize(info); - - // make sure there is no duplicated tensor names - for (uint64_t j = 0; j < i && ok; ++j) { - if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) { - fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data); - ok = false; - } - } - - if (!ok) { - fprintf(stderr, "%s: failed to read tensor info\n", __func__); - gguf_free(ctx); - return NULL; - } - } - } - - ctx->alignment = GGUF_DEFAULT_ALIGNMENT; - - int alignment_idx = gguf_find_key(ctx, "general.alignment"); - if (alignment_idx != -1) { - ctx->alignment = gguf_get_val_u32(ctx, alignment_idx); - } - - // we require the data section to be aligned, so take into account any padding - { - const size_t offset_pad = offset % ctx->alignment; - - if (offset_pad != 0) { - offset += ctx->alignment - offset_pad; - fseek(file, offset, SEEK_SET); - } - } - - // store the current file offset - this 
is where the data section starts - ctx->offset = offset; - - // compute the total size of the data section, taking into account the alignment - { - ctx->size = 0; - for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - const int64_t ne = - (int64_t) info->ne[0] * - (int64_t) info->ne[1] * - (int64_t) info->ne[2] * - (int64_t) info->ne[3]; - - if (ggml_blck_size(info->type) == 0 ) { - // this tensor type support have been removed: - fprintf(stderr, "%s: tensor '%s' of type %d: %s\n", - __func__, info->name.data, (int) info->type, ggml_type_name(info->type)); - gguf_free(ctx); - return NULL; - } - - if (ne % ggml_blck_size(info->type) != 0) { - fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", - __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); - gguf_free(ctx); - return NULL; - } - - const size_t size_cur = ggml_row_size(info->type, ne); - - ctx->size += GGML_PAD(size_cur, ctx->alignment); - } - } - - // load the tensor data only if requested - if (params.ctx != NULL) { - // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob - // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of - // the ggml_tensor structs to the appropriate locations in the binary blob - - // compute the exact size needed for the new ggml_context - const size_t mem_size = - params.no_alloc ? - (ctx->header.n_tensors )*ggml_tensor_overhead() : - (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size; - - struct ggml_init_params pdata = { - .mem_size = mem_size, - .mem_buffer = NULL, - .no_alloc = params.no_alloc, - }; - - *params.ctx = ggml_init(pdata); - if (*params.ctx == NULL) { - fprintf(stderr, "%s: failed to initialize context\n", __func__); - gguf_free(ctx); - return NULL; - } - - struct ggml_context * ctx_data = *params.ctx; - - struct ggml_tensor * data = NULL; - - if (!params.no_alloc) { - data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); - - ok = ok && data != NULL; - - // read the binary blob with the tensor data - ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset); - - if (!ok) { - fprintf(stderr, "%s: failed to read tensor data\n", __func__); - ggml_free(ctx_data); - gguf_free(ctx); - return NULL; - } - - ctx->data = data->data; - } - - ggml_set_no_alloc(ctx_data, true); - - // create the tensors - for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { - const int64_t ne[GGML_MAX_DIMS] = { - ctx->infos[i].ne[0], - ctx->infos[i].ne[1], - ctx->infos[i].ne[2], - ctx->infos[i].ne[3], - }; - - struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne); - - ok = ok && cur != NULL; - - if (!ok) { - break; - } - - ggml_set_name(cur, ctx->infos[i].name.data); - - // point the data member to the appropriate location in the binary blob using the tensor infos - if (!params.no_alloc) { - //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file - cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data - } - } - - if (!ok) { - fprintf(stderr, "%s: failed to read the tensor data\n", __func__); - ggml_free(ctx_data); - gguf_free(ctx); - return NULL; - } - - ggml_set_no_alloc(ctx_data, params.no_alloc); - } - - return ctx; -} - -struct gguf_context * gguf_init_from_file(const char * fname, 
struct gguf_init_params params) { - FILE * file = ggml_fopen(fname, "rb"); - if (!file) { - fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno)); - return NULL; - } - - struct gguf_context * result = gguf_init_from_file_impl(file, params); - fclose(file); - return result; -} - -void gguf_free(struct gguf_context * ctx) { - if (ctx == NULL) { - return; - } - - if (ctx->kv) { - // free string memory - not great.. - for (uint64_t i = 0; i < ctx->header.n_kv; ++i) { - gguf_free_kv(&ctx->kv[i]); - } - - GGML_FREE(ctx->kv); - } - - if (ctx->infos) { - for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - if (info->name.data) { - GGML_FREE(info->name.data); - } - } - - GGML_FREE(ctx->infos); - } - - GGML_FREE(ctx); -} - -const char * gguf_type_name(enum gguf_type type) { - return GGUF_TYPE_NAME[type]; -} - -int gguf_get_version(const struct gguf_context * ctx) { - return ctx->header.version; -} - -size_t gguf_get_alignment(const struct gguf_context * ctx) { - return ctx->alignment; -} - -size_t gguf_get_data_offset(const struct gguf_context * ctx) { - return ctx->offset; -} - -void * gguf_get_data(const struct gguf_context * ctx) { - return ctx->data; -} - -int gguf_get_n_kv(const struct gguf_context * ctx) { - return ctx->header.n_kv; -} - -int gguf_find_key(const struct gguf_context * ctx, const char * key) { - // return -1 if key not found - int keyfound = -1; - - const int n_kv = gguf_get_n_kv(ctx); - - for (int i = 0; i < n_kv; ++i) { - if (strcmp(key, gguf_get_key(ctx, i)) == 0) { - keyfound = i; - break; - } - } - - return keyfound; -} - -const char * gguf_get_key(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - return ctx->kv[key_id].key.data; -} - -enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - return ctx->kv[key_id].type; -} - -enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - return ctx->kv[key_id].value.arr.type; -} - -const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - return ctx->kv[key_id].value.arr.data; -} - -const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - struct gguf_kv * kv = &ctx->kv[key_id]; - struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; - return str->data; -} - -int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - return ctx->kv[key_id].value.arr.n; -} - -uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8); - return ctx->kv[key_id].value.uint8; -} - -int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8); - return ctx->kv[key_id].value.int8; -} - -uint16_t gguf_get_val_u16(const struct 
gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16); - return ctx->kv[key_id].value.uint16; -} - -int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16); - return ctx->kv[key_id].value.int16; -} - -uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32); - return ctx->kv[key_id].value.uint32; -} - -int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32); - return ctx->kv[key_id].value.int32; -} - -float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32); - return ctx->kv[key_id].value.float32; -} - -uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64); - return ctx->kv[key_id].value.uint64; -} - -int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64); - return ctx->kv[key_id].value.int64; -} - -double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64); - return ctx->kv[key_id].value.float64; -} - -bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL); - return ctx->kv[key_id].value.bool_; -} - -const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING); - return ctx->kv[key_id].value.str.data; -} - -const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY); - GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING); - return &ctx->kv[key_id].value; -} - -int gguf_get_n_tensors(const struct gguf_context * ctx) { - return ctx->header.n_tensors; -} - -int gguf_find_tensor(const struct gguf_context * ctx, const char * name) { - // return -1 if tensor not found - int tensorfound = -1; - - const int n_tensors = gguf_get_n_tensors(ctx); - - for (int i = 0; i < n_tensors; ++i) { - if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { - tensorfound = i; - break; - } - } - - return tensorfound; -} - -size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) { - return ctx->infos[i].offset; -} - -char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) { - return ctx->infos[i].name.data; -} - -enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) { - return ctx->infos[i].type; -} - -// returns the index -static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) { - const int idx = gguf_find_key(ctx, key); - if (idx >= 0) { - return idx; - } - - const int 
n_kv = gguf_get_n_kv(ctx); - - ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv)); - ctx->kv[n_kv].key.n = strlen(key); - ctx->kv[n_kv].key.data = strdup(key); - ctx->header.n_kv++; - - return n_kv; -} - -void gguf_remove_key(struct gguf_context * ctx, const char * key) { - const int idx = gguf_find_key(ctx, key); - if (idx >= 0) { - const int n_kv = gguf_get_n_kv(ctx); - gguf_free_kv(&ctx->kv[idx]); - for (int i = idx; i < n_kv-1; ++i) { - ctx->kv[i] = ctx->kv[i+1]; - } - ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv)); - ctx->header.n_kv--; - } -} - -void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT8; - ctx->kv[idx].value.uint8 = val; -} - -void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT8; - ctx->kv[idx].value.int8 = val; -} - -void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT16; - ctx->kv[idx].value.uint16 = val; -} - -void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT16; - ctx->kv[idx].value.int16 = val; -} - -void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT32; - ctx->kv[idx].value.uint32 = val; -} - -void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT32; - ctx->kv[idx].value.int32 = val; -} - -void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_FLOAT32; - ctx->kv[idx].value.float32 = val; -} - -void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT64; - ctx->kv[idx].value.uint64 = val; -} - -void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT64; - ctx->kv[idx].value.int64 = val; -} - -void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_FLOAT64; - ctx->kv[idx].value.float64 = val; -} - -void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_BOOL; - ctx->kv[idx].value.bool_ = val; -} - -void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_STRING; - ctx->kv[idx].value.str.n = strlen(val); - ctx->kv[idx].value.str.data = strdup(val); -} - -void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_ARRAY; - ctx->kv[idx].value.arr.type = type; - ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type)); - 
memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type)); -} - -void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_ARRAY; - ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING; - ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str)); - for (int i = 0; i < n; i++) { - struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i]; - str->n = strlen(data[i]); - str->data = strdup(data[i]); - } -} - -// set or add KV pairs from another context -void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { - for (uint32_t i = 0; i < src->header.n_kv; i++) { - switch (src->kv[i].type) { - case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break; - case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break; - case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break; - case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break; - case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break; - case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break; - case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break; - case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break; - case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break; - case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break; - case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break; - case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break; - case GGUF_TYPE_ARRAY: - { - if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) { - const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *)); - for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) { - data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data; - } - gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n); - GGML_FREE((void *)data); - } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) { - GGML_ABORT("nested arrays not supported"); - } else { - gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n); - } - } break; - default: GGML_ABORT("invalid type"); - } - } -} - -void gguf_add_tensor( - struct gguf_context * ctx, - const struct ggml_tensor * tensor) { - GGML_ASSERT(tensor); - if (gguf_find_tensor(ctx, tensor->name) != -1) { - GGML_ABORT("duplicated tensor name"); - } - - const int idx = ctx->header.n_tensors; - ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info)); - - ctx->infos[idx].name.n = strlen(tensor->name); - ctx->infos[idx].name.data = strdup(tensor->name); - - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - ctx->infos[idx].ne[i] = 1; - } - - ctx->infos[idx].n_dims = ggml_n_dims(tensor); - for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) { - ctx->infos[idx].ne[i] = tensor->ne[i]; - } - - ctx->infos[idx].type = tensor->type; - ctx->infos[idx].offset = 0; - ctx->infos[idx].data = tensor->data; - ctx->infos[idx].size = ggml_nbytes(tensor); - - if 
(ctx->header.n_tensors > 0) { - ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment); - } - - ctx->header.n_tensors++; -} - -void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { - const int idx = gguf_find_tensor(ctx, name); - if (idx < 0) { - GGML_ABORT("tensor not found"); - } - - ctx->infos[idx].type = type; -} - -void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) { - const int idx = gguf_find_tensor(ctx, name); - if (idx < 0) { - GGML_ABORT("tensor not found"); - } - - ctx->infos[idx].data = data; - ctx->infos[idx].size = size; - - // update offsets - for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) { - ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment); - } -} - -//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) { -// fwrite(&val->n, sizeof(val->n), 1, file); -// fwrite(val->data, sizeof(char), val->n, file); -//} -// -//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) { -// fwrite(val, sizeof(char), size, file); -//} - -struct gguf_buf gguf_buf_init(size_t size) { - struct gguf_buf buf = { - /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size), - /*buf.size =*/ size, - /*buf.offset =*/ 0, - }; - - return buf; -} - -void gguf_buf_free(struct gguf_buf buf) { - if (buf.data) { - GGML_FREE(buf.data); - } -} - -static void gguf_buf_grow(struct gguf_buf * buf, size_t size) { - if (buf->offset + size > buf->size) { - buf->size = 1.5*(buf->offset + size); - if (buf->data) { - buf->data = realloc(buf->data, buf->size); - } - } -} - -static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) { - gguf_buf_grow(buf, sizeof(val->n) + val->n); - - if (buf->data) { - memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n)); - } - buf->offset += sizeof(val->n); - - if (buf->data) { - memcpy((char *) buf->data + buf->offset, val->data, val->n); - } - buf->offset += val->n; -} - -static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) { - gguf_buf_grow(buf, el_size); - - if (buf->data) { - memcpy((char *) buf->data + buf->offset, val, el_size); - } - buf->offset += el_size; -} - -void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) { - // write header - gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic)); - gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version)); - gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors)); - gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv)); - - // write key-value pairs - for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { - struct gguf_kv * kv = &ctx->kv[i]; - - gguf_bwrite_str(buf, &kv->key); - gguf_bwrite_el (buf, &kv->type, sizeof(kv->type)); - - switch (kv->type) { - case GGUF_TYPE_UINT8: gguf_bwrite_el( buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break; - case GGUF_TYPE_INT8: gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break; - case GGUF_TYPE_UINT16: gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break; - case GGUF_TYPE_INT16: gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break; - case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break; - case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break; - case 
GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break; - case GGUF_TYPE_UINT64: gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break; - case GGUF_TYPE_INT64: gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break; - case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break; - case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break; - case GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break; - case GGUF_TYPE_ARRAY: - { - gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type)); - gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) ); - - switch (kv->value.arr.type) { - case GGUF_TYPE_UINT8: - case GGUF_TYPE_INT8: - case GGUF_TYPE_UINT16: - case GGUF_TYPE_INT16: - case GGUF_TYPE_UINT32: - case GGUF_TYPE_INT32: - case GGUF_TYPE_FLOAT32: - case GGUF_TYPE_UINT64: - case GGUF_TYPE_INT64: - case GGUF_TYPE_FLOAT64: - case GGUF_TYPE_BOOL: - { - gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type)); - } break; - case GGUF_TYPE_STRING: - { - for (uint32_t j = 0; j < kv->value.arr.n; ++j) { - gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]); - } - } break; - case GGUF_TYPE_ARRAY: - default: GGML_ABORT("invalid type"); - } - } break; - default: GGML_ABORT("invalid type"); - } - } - - // write tensor infos - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - gguf_bwrite_str(buf, &info->name); - gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims)); - for (uint32_t j = 0; j < info->n_dims; ++j) { - gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j])); - } - gguf_bwrite_el(buf, &info->type, sizeof(info->type)); - gguf_bwrite_el(buf, &info->offset, sizeof(info->offset)); - } - - // we require the data section to be aligned, so take into account any padding - { - const size_t offset = buf->offset; - const size_t offset_pad = GGML_PAD(offset, ctx->alignment); - - if (offset_pad != offset) { - uint8_t pad = 0; - for (size_t i = 0; i < offset_pad - offset; ++i) { - gguf_bwrite_el(buf, &pad, sizeof(pad)); - } - } - } - - if (only_meta) { - return; - } - - size_t offset = 0; - - // write tensor data - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - const size_t size = info->size; - const size_t size_pad = GGML_PAD(size, ctx->alignment); - - gguf_bwrite_el(buf, info->data, size); - - if (size_pad != size) { - uint8_t pad = 0; - for (size_t j = 0; j < size_pad - size; ++j) { - gguf_bwrite_el(buf, &pad, sizeof(pad)); - } - } - - GGML_ASSERT(offset == info->offset); - - offset += size_pad; - } -} - -void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { - FILE * file = ggml_fopen(fname, "wb"); - if (!file) { - GGML_ABORT("failed to open file for writing"); - } - - struct gguf_buf buf = gguf_buf_init(16*1024); - - gguf_write_to_buf(ctx, &buf, only_meta); - - fwrite(buf.data, 1, buf.offset, file); - - gguf_buf_free(buf); - - fclose(file); -} - -size_t gguf_get_meta_size(const struct gguf_context * ctx) { - // no allocs - only compute size - struct gguf_buf buf = gguf_buf_init(0); - - gguf_write_to_buf(ctx, &buf, true); - - return buf.offset; -} - -void gguf_get_meta_data(const struct gguf_context * ctx, void * data) { - struct gguf_buf buf = gguf_buf_init(16*1024); - - gguf_write_to_buf(ctx, &buf, true); - - 
memcpy(data, buf.data, buf.offset); - - gguf_buf_free(buf); -} - void ggml_log_set(ggml_log_callback log_callback, void * user_data) { g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default; g_logger_state.log_callback_user_data = user_data; diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp new file mode 100644 index 000000000..655ed600a --- /dev/null +++ b/ggml/src/gguf.cpp @@ -0,0 +1,1325 @@ +#include "ggml.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "gguf.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +struct type_to_gguf_type; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT8; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT8; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT16; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT16; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT32; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT32; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_FLOAT32; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_BOOL; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_STRING; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT64; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT64; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_FLOAT64; +}; + +static const std::map GGUF_TYPE_SIZE = { + {GGUF_TYPE_UINT8, sizeof(uint8_t)}, + {GGUF_TYPE_INT8, sizeof(int8_t)}, + {GGUF_TYPE_UINT16, sizeof(uint16_t)}, + {GGUF_TYPE_INT16, sizeof(int16_t)}, + {GGUF_TYPE_UINT32, sizeof(uint32_t)}, + {GGUF_TYPE_INT32, sizeof(int32_t)}, + {GGUF_TYPE_FLOAT32, sizeof(float)}, + {GGUF_TYPE_BOOL, sizeof(int8_t)}, + {GGUF_TYPE_STRING, 0}, // undefined + {GGUF_TYPE_ARRAY, 0}, // undefined + {GGUF_TYPE_UINT64, sizeof(uint64_t)}, + {GGUF_TYPE_INT64, sizeof(int64_t)}, + {GGUF_TYPE_FLOAT64, sizeof(double)}, +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +static const std::map GGUF_TYPE_NAME = { + {GGUF_TYPE_UINT8, "u8"}, + {GGUF_TYPE_INT8, "i8"}, + {GGUF_TYPE_UINT16, "u16"}, + {GGUF_TYPE_INT16, "i16"}, + {GGUF_TYPE_UINT32, "u32"}, + {GGUF_TYPE_INT32, "i32"}, + {GGUF_TYPE_FLOAT32, "f32"}, + {GGUF_TYPE_BOOL, "bool"}, + {GGUF_TYPE_STRING, "str"}, + {GGUF_TYPE_ARRAY, "arr"}, + {GGUF_TYPE_UINT64, "u64"}, + {GGUF_TYPE_INT64, "i64"}, + {GGUF_TYPE_FLOAT64, "f64"}, +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +size_t gguf_type_size(enum gguf_type type) { + auto it = GGUF_TYPE_SIZE.find(type); + return it == GGUF_TYPE_SIZE.end() ? 
0 : it->second; +} + +struct gguf_kv { + std::string key; + + bool is_array; + enum gguf_type type; + + std::vector data; + std::vector data_string; + + template + gguf_kv(const std::string & key, const T value) + : key(key), is_array(false), type(type_to_gguf_type::value) { + GGML_ASSERT(!key.empty()); + data.resize(sizeof(T)); + memcpy(data.data(), &value, sizeof(T)); + } + + template + gguf_kv(const std::string & key, const std::vector & value) + : key(key), is_array(true), type(type_to_gguf_type::value) { + GGML_ASSERT(!key.empty()); + data.resize(value.size()*sizeof(T)); + for (size_t i = 0; i < value.size(); ++i) { + const T tmp = value[i]; + memcpy(data.data() + i*sizeof(T), &tmp, sizeof(T)); + } + } + + gguf_kv(const std::string & key, const std::string & value) + : key(key), is_array(false), type(GGUF_TYPE_STRING) { + GGML_ASSERT(!key.empty()); + data_string.push_back(value); + } + + gguf_kv(const std::string & key, const std::vector & value) + : key(key), is_array(true), type(GGUF_TYPE_STRING) { + GGML_ASSERT(!key.empty()); + data_string = value; + } + + const std::string & get_key() const { + return key; + } + + const enum gguf_type & get_type() const { + return type; + } + + size_t get_ne() const { + if (type == GGUF_TYPE_STRING) { + const size_t ne = data_string.size(); + GGML_ASSERT(is_array || ne == 1); + return ne; + } + const size_t type_size = gguf_type_size(type); + GGML_ASSERT(data.size() % type_size == 0); + const size_t ne = data.size() / type_size; + GGML_ASSERT(is_array || ne == 1); + return ne; + } + + template + const T & get_val(const size_t i = 0) const { + GGML_ASSERT(type_to_gguf_type::value == type); + if constexpr (std::is_same::value) { + GGML_ASSERT(data_string.size() >= i+1); + return data_string[i]; + } + const size_t type_size = gguf_type_size(type); + GGML_ASSERT(data.size() % type_size == 0); + GGML_ASSERT(data.size() >= (i+1)*type_size); + return reinterpret_cast(data.data())[i]; + } + + void cast(const enum gguf_type new_type) { + const size_t new_type_size = gguf_type_size(new_type); + GGML_ASSERT(data.size() % new_type_size == 0); + type = new_type; + } +}; + +struct gguf_tensor_info { + struct ggml_tensor t; // for holding the equivalent info + uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` +}; + +struct gguf_context { + uint32_t version = GGUF_VERSION; + + std::vector kv; + std::vector info; + + size_t alignment = GGUF_DEFAULT_ALIGNMENT; + size_t offset = 0; // offset of `data` from beginning of file + size_t size = 0; // size of `data` in bytes + + void * data = nullptr; +}; + +struct gguf_reader { + FILE * file; + + gguf_reader(FILE * file) : file(file) {} + + template + bool read(T & dst) const { + return fread(&dst, 1, sizeof(dst), file) == sizeof(dst); + } + + template + bool read(std::vector & dst, const size_t n) const { + dst.resize(n); + for (size_t i = 0; i < dst.size(); ++i) { + if constexpr (std::is_same::value) { + bool tmp; + if (!read(tmp)) { + return false; + } + dst[i] = tmp; + } else { + if (!read(dst[i])) { + return false; + } + } + } + return true; + } + + bool read(bool & dst) const { + int8_t tmp = -1; + if (!read(tmp)) { + return false; + } + dst = tmp != 0; + return true; + } + + bool read(enum ggml_type & dst) const { + int32_t tmp = -1; + if (!read(tmp)) { + return false; + } + dst = ggml_type(tmp); + return true; + } + + bool read(enum gguf_type & dst) const { + int32_t tmp = -1; + if (!read(tmp)) { + return false; + } + dst = gguf_type(tmp); + return true; + } + + bool 
read(std::string & dst) const { + uint64_t size = -1; + if (!read(size)) { + return false; + } + dst.resize(size); + return fread(dst.data(), 1, dst.length(), file) == dst.length(); + } + + bool read(void * dst, const size_t size) const { + return fread(dst, 1, size, file) == size; + } +}; + +struct gguf_context * gguf_init_empty(void) { + return new gguf_context; +} + +template +bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector & kv, const std::string & key, const bool is_array, const size_t n) { + if (is_array) { + std::vector value; + try { + if (!gr.read(value, n)) { + return false; + } + } catch (std::length_error &) { + fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str()); + return false; + } catch (std::bad_alloc &) { + fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str()); + return false; + } + kv.emplace_back(key, value); + } else { + T value; + if (!gr.read(value)) { + return false; + } + kv.emplace_back(key, value); + } + return true; +} + +struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) { + const struct gguf_reader gr(file); + struct gguf_context * ctx = new gguf_context; + + bool ok = true; + + // file magic + { + std::vector magic; + ok = ok && gr.read(magic, 4); + + if (!ok) { + fprintf(stderr, "%s: failed to read magic\n", __func__); + gguf_free(ctx); + return nullptr; + } + + for (uint32_t i = 0; i < magic.size(); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]); + gguf_free(ctx); + return nullptr; + } + } + } + + // header + int64_t n_kv = 0; + int64_t n_tensors = 0; + + if (ok && gr.read(ctx->version)) { + if (ctx->version == 1) { + fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); + ok = false; + } + if (ctx->version > GGUF_VERSION) { + fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", + __func__, ctx->version, GGUF_VERSION); + ok = false; + } + } else { + ok = false; + } + + if (ok && gr.read(n_tensors)) { + static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); + if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) { + fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n", + __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info)); + ok = false; + } + } else { + ok = false; + } + + if (ok && gr.read(n_kv)) { + static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); + if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) { + fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n", + __func__, n_kv, SIZE_MAX/sizeof(gguf_kv)); + ok = false; + } + } else { + ok = false; + } + + if (!ok) { + fprintf(stderr, "%s: failed to read header\n", __func__); + gguf_free(ctx); + return nullptr; + } + + // KV pairs + { + for (int64_t i = 0; ok && i < n_kv; ++i) { + std::string key; + gguf_type type = gguf_type(-1); + bool is_array = false; + uint64_t n = 1; + + try { + ok = ok && gr.read(key); + } catch (std::length_error &) { + fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i); + ok = false; + } catch (std::bad_alloc &) { + fprintf(stderr, "%s: 
encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i); + ok = false; + } + for (size_t j = 0; ok && j < ctx->kv.size(); ++j) { + if (key == ctx->kv[j].key) { + fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i); + ok = false; + } + } + if (!ok) { + break; + } + + ok = ok && gr.read(type); + if (type == GGUF_TYPE_ARRAY) { + is_array = true; + ok = ok && gr.read(type); + ok = ok && gr.read(n); + } + if (!ok) { + break; + } + + switch (type) { + case GGUF_TYPE_UINT8: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT8: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_UINT16: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT16: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_UINT32: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT32: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_FLOAT32: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_BOOL: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_STRING: ok = ok && gguf_read_emplace_helper(gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_UINT64: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT64: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_FLOAT64: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_ARRAY: + default: + { + fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type); + ok = false; + } break; + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); + gguf_free(ctx); + return nullptr; + } + GGML_ASSERT(int64_t(ctx->kv.size()) == n_kv); + + const int alignment_idx = gguf_find_key(ctx, GGUF_KEY_GENERAL_ALIGNMENT); + ctx->alignment = alignment_idx == -1 ? 
GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx); + + if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) { + fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment); + gguf_free(ctx); + return nullptr; + } + } + + // read the tensor info + for (int64_t i = 0; ok && i < n_tensors; ++i) { + struct gguf_tensor_info info; + + // tensor name + { + std::string name; + try { + ok = ok && gr.read(name); + } catch (std::length_error &) { + fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i); + ok = false; + } catch (std::bad_alloc &) { + fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i); + ok = false; + } + if (name.length() >= GGML_MAX_NAME) { + fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME); + ok = false; + break; + } + ggml_set_name(&info.t, name.c_str()); + + // make sure there are no duplicate tensor names + for (int64_t j = 0; ok && j < i; ++j) { + if (strcmp(info.t.name, ctx->info[j].t.name) == 0) { + fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i); + ok = false; + break; + } + } + } + if (!ok) { + break; + } + + // tensor shape + { + uint32_t n_dims = -1; + ok = ok && gr.read(n_dims); + if (n_dims > GGML_MAX_DIMS) { + fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", + __func__, info.t.name, n_dims, GGML_MAX_DIMS); + ok = false; + break; + } + for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) { + info.t.ne[j] = 1; + if (j < n_dims) { + ok = ok && gr.read(info.t.ne[j]); + } + + // check that all ne are non-negative + if (info.t.ne[j] < 0) { + fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", + __func__, info.t.name, j, info.t.ne[j]); + ok = false; + break; + } + } + + // check that the total number of elements is representable + if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) || + (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) || + (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) { + + fprintf(stderr, "%s: total number of elements in tensor '%s' with shape " + "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n", + __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX); + ok = false; + break; + } + } + if (!ok) { + break; + } + + // tensor type + { + ok = ok && gr.read(info.t.type); + + // check that tensor type is within defined range + if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) { + fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n", + __func__, info.t.name, info.t.type, ggml_type_name(info.t.type)); + ok = false; + break; + } + const size_t type_size = ggml_type_size(info.t.type); + const int64_t blck_size = ggml_blck_size(info.t.type); + + // check that row size is divisible by block size + if (blck_size == 0 || info.t.ne[0] % blck_size != 0) { + fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, " + "not a multiple of block size (%" PRId64 ")\n", + __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size); + ok = false; + break; + } + + // calculate byte offsets given the tensor shape and type + info.t.nb[0] = type_size; + info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size); + for (int j = 2; j < 
GGML_MAX_DIMS; ++j) { + info.t.nb[j] = info.t.nb[j - 1]*info.t.ne[j - 1]; + } + } + if (!ok) { + break; + } + + // tensor data offset within buffer + ok = ok && gr.read(info.offset); + + ctx->info.push_back(info); + } + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor info\n", __func__); + gguf_free(ctx); + return nullptr; + } + GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors); + + // we require the data section to be aligned, so take into account any padding + if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) { + fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__); + gguf_free(ctx); + return nullptr; + } + + // store the current file offset - this is where the data section starts + ctx->offset = ftell(file); + + // compute the total size of the data section, taking into account the alignment + { + ctx->size = 0; + for (size_t i = 0; i < ctx->info.size(); ++i) { + const gguf_tensor_info & ti = ctx->info[i]; + if (ti.offset != ctx->size) { + fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n", + __func__, ti.t.name, ti.offset, ctx->size); + fprintf(stderr, "%s: failed to read tensor data\n", __func__); + gguf_free(ctx); + return nullptr; + } + ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment); + } + } + + // load the tensor data only if requested + if (params.ctx != nullptr) { + // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob + // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of + // the ggml_tensor structs to the appropriate locations in the binary blob + + // compute the exact size needed for the new ggml_context + const size_t mem_size = + params.no_alloc ? 
+ (n_tensors )*ggml_tensor_overhead() : + (n_tensors + 1)*ggml_tensor_overhead() + ctx->size; + + struct ggml_init_params pdata = { + /*mem_size =*/ mem_size, + /*mem_buffer =*/ nullptr, + /*no_alloc =*/ params.no_alloc, + }; + + *params.ctx = ggml_init(pdata); + if (*params.ctx == nullptr) { + fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__); + gguf_free(ctx); + return nullptr; + } + + struct ggml_context * ctx_data = *params.ctx; + + struct ggml_tensor * data = nullptr; + + if (!params.no_alloc) { + data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); + + ok = ok && data != nullptr; + + // read the binary blob with the tensor data + ok = ok && gr.read(data->data, ctx->size); + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__); + ggml_free(ctx_data); + *params.ctx = nullptr; + gguf_free(ctx); + return nullptr; + } + + ctx->data = data->data; + } + + ggml_set_no_alloc(ctx_data, true); + + // create the tensors + for (size_t i = 0; i < ctx->info.size(); ++i) { + const struct gguf_tensor_info & info = ctx->info[i]; + + struct ggml_tensor * cur = ggml_new_tensor(ctx_data, info.t.type, GGML_MAX_DIMS, info.t.ne); + + ok = ok && cur != nullptr; + + if (!ok) { + break; + } + + ggml_set_name(cur, info.t.name); + + // point the data member to the appropriate location in the binary blob using the tensor info + if (!params.no_alloc) { + cur->data = (char *) data->data + info.offset; + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to create tensors\n", __func__); + ggml_free(ctx_data); + *params.ctx = nullptr; + gguf_free(ctx); + return nullptr; + } + + ggml_set_no_alloc(ctx_data, params.no_alloc); + } + + return ctx; +} + +struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { + FILE * file = ggml_fopen(fname, "rb"); + + if (!file) { + fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname); + return nullptr; + } + + struct gguf_context * result = gguf_init_from_file_impl(file, params); + fclose(file); + return result; +} + +void gguf_free(struct gguf_context * ctx) { + if (ctx == nullptr) { + return; + } + delete ctx; +} + +const char * gguf_type_name(enum gguf_type type) { + auto it = GGUF_TYPE_NAME.find(type); + return it == GGUF_TYPE_NAME.end() ? nullptr : it->second; +} + +uint32_t gguf_get_version(const struct gguf_context * ctx) { + return ctx->version; +} + +size_t gguf_get_alignment(const struct gguf_context * ctx) { + return ctx->alignment; +} + +size_t gguf_get_data_offset(const struct gguf_context * ctx) { + return ctx->offset; +} + +int64_t gguf_get_n_kv(const struct gguf_context * ctx) { + return ctx->kv.size(); +} + +int64_t gguf_find_key(const struct gguf_context * ctx, const char * key) { + // return -1 if key not found + int64_t keyfound = -1; + + const int64_t n_kv = gguf_get_n_kv(ctx); + + for (int64_t i = 0; i < n_kv; ++i) { + if (strcmp(key, gguf_get_key(ctx, i)) == 0) { + keyfound = i; + break; + } + } + + return keyfound; +} + +const char * gguf_get_key(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].get_key().c_str(); +} + +enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].is_array ? 
GGUF_TYPE_ARRAY : ctx->kv[key_id].get_type(); +} + +enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].is_array); + return ctx->kv[key_id].get_type(); +} + +const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); + return ctx->kv[key_id].data.data(); +} + +const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING); + return ctx->kv[key_id].data_string[i].c_str(); +} + +size_t gguf_get_arr_n(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + + if (ctx->kv[key_id].type == GGUF_TYPE_STRING) { + return ctx->kv[key_id].data_string.size(); + } + + const size_t type_size = gguf_type_size(ctx->kv[key_id].type); + GGML_ASSERT(ctx->kv[key_id].data.size() % type_size == 0); + return ctx->kv[key_id].data.size() / type_size; +} + +uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int8_t gguf_get_val_i8(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int16_t gguf_get_val_i16(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int32_t gguf_get_val_i32(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +float gguf_get_val_f32(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int64_t gguf_get_val_i64(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +double gguf_get_val_f64(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < 
gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val().c_str(); +} + +const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); + return ctx->kv[key_id].data.data(); +} + +int64_t gguf_get_n_tensors(const struct gguf_context * ctx) { + return ctx->info.size(); +} + +int64_t gguf_find_tensor(const struct gguf_context * ctx, const char * name) { + // return -1 if tensor not found + int64_t tensor_id = -1; + + const int64_t n_tensors = gguf_get_n_tensors(ctx); + + for (int64_t i = 0; i < n_tensors; ++i) { + if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { + tensor_id = i; + break; + } + } + + return tensor_id; +} + +size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ctx->info[tensor_id].offset; +} + +const char * gguf_get_tensor_name(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ctx->info[tensor_id].t.name; +} + +enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ctx->info[tensor_id].t.type; +} + +size_t gguf_get_tensor_size(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ggml_nbytes(&ctx->info[tensor_id].t); +} + +int64_t gguf_remove_key(struct gguf_context * ctx, const char * key) { + const int64_t key_id = gguf_find_key(ctx, key); + if (key_id >= 0) { + ctx->kv.erase(ctx->kv.begin() + key_id); + } + return key_id; +} + +template +static void gguf_check_reserved_keys(const std::string & key, const T val) { + if (key == GGUF_KEY_GENERAL_ALIGNMENT) { + if constexpr (std::is_same::value) { + GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2"); + } else { + GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32"); + } + } +} + +void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) { + gguf_check_reserved_keys(key, val); + 
gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, std::string(val)); +} + +void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n) { + gguf_check_reserved_keys(key, data); + gguf_remove_key(ctx, key); + + const size_t nbytes = n*gguf_type_size(type); + std::vector tmp(nbytes); + if (!tmp.empty()) { + memcpy(tmp.data(), data, nbytes); + } + ctx->kv.emplace_back(key, tmp); + ctx->kv.back().cast(type); +} + +void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, size_t n) { + gguf_check_reserved_keys(key, data); + gguf_remove_key(ctx, key); + + std::vector tmp(n); + for (size_t i = 0; i < n; ++i) { + tmp[i] = data[i]; + } + ctx->kv.emplace_back(key, tmp); +} + +// set or add KV pairs from another context +void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src) { + const int64_t n_kv = gguf_get_n_kv(src); + for (int64_t i = 0; i < n_kv; ++i) { + const struct gguf_kv & kv = src->kv[i]; + + if (!kv.is_array) { + switch (kv.get_type()) { + case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_STRING: gguf_set_val_str (ctx, kv.get_key().c_str(), kv.get_val().c_str()); break; + case GGUF_TYPE_ARRAY: + default: GGML_ABORT("invalid type"); + } + continue; + } + + const size_t ne = kv.get_ne(); + + switch (kv.get_type()) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case 
GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: + case GGUF_TYPE_BOOL: { + gguf_set_arr_data(ctx, kv.get_key().c_str(), kv.get_type(), kv.data.data(), ne); + } break; + case GGUF_TYPE_STRING: { + std::vector tmp(ne); + for (size_t j = 0; j < ne; ++j) { + tmp[j] = kv.data_string[j].c_str(); + } + gguf_set_arr_str(ctx, kv.get_key().c_str(), tmp.data(), ne); + } break; + case GGUF_TYPE_ARRAY: + default: GGML_ABORT("invalid type"); + } + } +} + +void gguf_add_tensor( + struct gguf_context * ctx, + const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor); + if (gguf_find_tensor(ctx, tensor->name) != -1) { + GGML_ABORT("duplicate tensor name: %s", tensor->name); + } + + struct gguf_tensor_info ti; + ti.t = *tensor; + ti.offset = ctx->info.empty() ? 0 : + ctx->info.back().offset + GGML_PAD(ggml_nbytes(&ctx->info.back().t), ctx->alignment); + ctx->info.push_back(ti); +} + +void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { + const int64_t tensor_id = gguf_find_tensor(ctx, name); + if (tensor_id < 0) { + GGML_ABORT("tensor not found: %s", name); + } + struct ggml_tensor * tensor = &ctx->info[tensor_id].t; + const size_t type_size = ggml_type_size(type); + const int64_t blck_size = ggml_blck_size(type); + + tensor->type = type; + GGML_ASSERT(tensor->ne[0] % blck_size == 0 && "tensor row size not divisible by block size of new type"); + + tensor->nb[0] = type_size; + tensor->nb[1] = tensor->nb[0]*(tensor->ne[0]/blck_size); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + tensor->nb[i] = tensor->nb[i - 1]*tensor->ne[i - 1]; + } + + // update offsets + const int64_t n_tensors = gguf_get_n_tensors(ctx); + for (int64_t i = tensor_id + 1; i < n_tensors; ++i) { + ctx->info[i].offset = ctx->info[i - 1].offset + GGML_PAD(ggml_nbytes(&ctx->info[i - 1].t), ctx->alignment); + } +} + +void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data) { + const int64_t tensor_id = gguf_find_tensor(ctx, name); + if (tensor_id < 0) { + GGML_ABORT("tensor not found: %s", name); + } + + ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const +} + +struct gguf_writer { + std::vector & buf; + + gguf_writer(std::vector & buf) : buf(buf) {} + + template + void write(const T & val) const { + for (size_t i = 0; i < sizeof(val); ++i) { + buf.push_back(reinterpret_cast(&val)[i]); + } + } + + void write(const std::vector & val) const { + buf.insert(buf.end(), val.begin(), val.end()); + } + + void write(const bool & val) const { + const int8_t val8 = val ? 
1 : 0; + write(val8); + } + + void write(const std::string & val) const { + { + const uint64_t n = val.length(); + write(n); + } + for (size_t i = 0; i < val.length(); ++i) { + buf.push_back(reinterpret_cast(val.data())[i]); + } + } + + void write(const char * val) const { + write(std::string(val)); + } + + void write(const enum ggml_type & val) const { + write(int32_t(val)); + } + + void write(const enum gguf_type & val) const { + write(int32_t(val)); + } + + void write(const struct gguf_kv & kv) const { + const uint64_t ne = kv.get_ne(); + + write(kv.get_key()); + + if (kv.is_array) { + write(GGUF_TYPE_ARRAY); + write(kv.get_type()); + write(ne); + } else { + write(kv.get_type()); + } + + switch (kv.get_type()) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: { + write(kv.data); + } break; + case GGUF_TYPE_BOOL: { + for (size_t i = 0; i < ne; ++i) { + write(kv.get_val(i)); + } + } break; + case GGUF_TYPE_STRING: { + for (size_t i = 0; i < ne; ++i) { + write(kv.get_val(i)); + } + } break; + case GGUF_TYPE_ARRAY: + default: GGML_ABORT("invalid type"); + } + } + + void write_tensor_meta(const struct gguf_tensor_info & info) const { + write(info.t.name); + + const uint32_t n_dims = ggml_n_dims(&info.t); + write(n_dims); + + for (uint32_t j = 0; j < n_dims; ++j) { + write(info.t.ne[j]); + } + write(info.t.type); + write(info.offset); + } + + void pad(const size_t alignment) const { + while (buf.size() % alignment != 0) { + const int8_t zero = 0; + write(zero); + } + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const { + GGML_ASSERT(buf.size() - offset_data == info.offset); + + GGML_ASSERT(ggml_is_contiguous(&info.t)); + const size_t offset = buf.size(); + const size_t nbytes = ggml_nbytes(&info.t); + + buf.resize(offset + nbytes); + if (info.t.buffer) { + ggml_backend_tensor_get(&info.t, buf.data() + offset, 0, nbytes); + } else { + GGML_ASSERT(info.t.data); + memcpy(buf.data() + offset, info.t.data, nbytes); + } + + pad(alignment); + } +}; + +void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) { + const struct gguf_writer gw(buf); + + const int64_t n_kv = gguf_get_n_kv(ctx); + const int64_t n_tensors = gguf_get_n_tensors(ctx); + + // write header + gw.write(GGUF_MAGIC[0]); + gw.write(GGUF_MAGIC[1]); + gw.write(GGUF_MAGIC[2]); + gw.write(GGUF_MAGIC[3]); + gw.write(ctx->version); + gw.write(n_tensors); + gw.write(n_kv); + + // write key-value pairs + for (int64_t i = 0; i < n_kv; ++i) { + gw.write(ctx->kv[i]); + } + + // write tensor info + for (int64_t i = 0; i < n_tensors; ++i) { + gw.write_tensor_meta(ctx->info[i]); + } + + // we require the data section to be aligned + gw.pad(ctx->alignment); + + if (only_meta) { + return; + } + + const size_t offset_data = gw.buf.size(); + + // write tensor data + for (int64_t i = 0; i < n_tensors; ++i) { + gw.write_tensor_data(ctx->info[i], offset_data, ctx->alignment); + } +} + +bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { + FILE * file = ggml_fopen(fname, "wb"); + + if (!file) { + fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname); + return false; + } + + std::vector buf; + gguf_write_to_buf(ctx, buf, only_meta); + const bool ok = fwrite(buf.data(), 1, 
buf.size(), file) == buf.size(); + fclose(file); + return ok; +} + +size_t gguf_get_meta_size(const struct gguf_context * ctx) { + // only return size + std::vector buf; + gguf_write_to_buf(ctx, buf, /*only_meta =*/ true); + return buf.size(); +} + +void gguf_get_meta_data(const struct gguf_context * ctx, void * data) { + std::vector buf; + gguf_write_to_buf(ctx, buf, /*only_meta =*/ true); + memcpy(data, buf.data(), buf.size()); +} diff --git a/gguf-py/README.md b/gguf-py/README.md index 24af96a17..37a75923b 100644 --- a/gguf-py/README.md +++ b/gguf-py/README.md @@ -15,13 +15,13 @@ pip install gguf [examples/writer.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model. -[scripts/gguf_dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console. +[gguf/scripts/gguf_dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console. -[scripts/gguf_set_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key. +[gguf/scripts/gguf_set_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key. -[scripts/gguf_convert_endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files. +[gguf/scripts/gguf_convert_endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files. -[scripts/gguf_new_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values. +[gguf/scripts/gguf_new_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values. 
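
For orientation (this aside is not part of the patch itself), the metadata dump that `gguf_dump.py` performs can also be done against the C API implemented in `ggml/src/gguf.cpp` above. A minimal sketch, assuming a local file named `model.gguf` (hypothetical) and the `gguf.h`/`ggml.h` headers on the include path:

```cpp
// Sketch only: dump GGUF key-value metadata via the gguf C API.
// Only the KV section is inspected; no ggml_context is requested,
// so tensor data is never loaded into memory.
#include "gguf.h"

#include <cstdint>
#include <cstdio>

int main() {
    struct gguf_init_params params;
    params.no_alloc = true;    // metadata only
    params.ctx      = nullptr; // do not create a ggml_context for the tensors

    struct gguf_context * ctx = gguf_init_from_file("model.gguf", params); // hypothetical file name
    if (!ctx) {
        return 1;
    }

    const int64_t n_kv = gguf_get_n_kv(ctx);
    for (int64_t i = 0; i < n_kv; ++i) {
        const char * type_name = gguf_type_name(gguf_get_kv_type(ctx, i));
        printf("%s: %s\n", gguf_get_key(ctx, i), type_name ? type_name : "?");
    }

    gguf_free(ctx);
    return 0;
}
```

The Python scripts listed above remain the convenient way to inspect files; the sketch only illustrates that the same information is reachable from C/C++ after this refactor.
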
## Development Maintainers who participate in development of this package are advised to install it in editable mode: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 273370370..cf05bf47e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -102,6 +102,8 @@ class Keys: EXPERT_USED_COUNT = "{arch}.expert_used_count" EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" + EXPERT_GATING_FUNC = "{arch}.expert_gating_func" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" @@ -242,6 +244,7 @@ class MODEL_ARCH(IntEnum): QWEN2VL = auto() PHI2 = auto() PHI3 = auto() + PHIMOE = auto() PLAMO = auto() CODESHELL = auto() ORION = auto() @@ -255,6 +258,7 @@ class MODEL_ARCH(IntEnum): MAMBA = auto() XVERSE = auto() COMMAND_R = auto() + COHERE2 = auto() DBRX = auto() OLMO = auto() OLMO2 = auto() @@ -312,6 +316,7 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() + FFN_EXP_PROBS_B = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() @@ -424,6 +429,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN2VL: "qwen2vl", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI3: "phi3", + MODEL_ARCH.PHIMOE: "phimoe", MODEL_ARCH.PLAMO: "plamo", MODEL_ARCH.CODESHELL: "codeshell", MODEL_ARCH.ORION: "orion", @@ -437,6 +443,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.COHERE2: "cohere2", MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.OLMO2: "olmo2", @@ -496,6 +503,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", @@ -934,6 +942,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PHIMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.CODESHELL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.POS_EMBD, @@ -1136,6 +1162,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_K_NORM, MODEL_TENSOR.ATTN_Q_NORM, ], + MODEL_ARCH.COHERE2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.DBRX: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1276,6 +1314,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, ], MODEL_ARCH.CHATGLM : [ 
MODEL_TENSOR.TOKEN_EMBD, @@ -1576,6 +1615,11 @@ class GGMLQuantizationType(IntEnum): TQ2_0 = 35 +class ExpertGatingFuncType(IntEnum): + SOFTMAX = 1 + SIGMOID = 2 + + # TODO: add GGMLFileType from ggml_ftype in ggml.h diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 3023b539a..4a0a65e3c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -26,6 +26,7 @@ from .constants import ( RopeScalingType, PoolingType, TokenType, + ExpertGatingFuncType, ) from .quants import quant_shape_from_byte_shape @@ -715,6 +716,12 @@ class GGUFWriter: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_expert_weights_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value) + + def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: + self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value) + def add_swin_norm(self, value: bool) -> None: self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) diff --git a/gguf-py/scripts/__init__.py b/gguf-py/gguf/scripts/__init__.py similarity index 100% rename from gguf-py/scripts/__init__.py rename to gguf-py/gguf/scripts/__init__.py diff --git a/gguf-py/scripts/gguf_convert_endian.py b/gguf-py/gguf/scripts/gguf_convert_endian.py similarity index 100% rename from gguf-py/scripts/gguf_convert_endian.py rename to gguf-py/gguf/scripts/gguf_convert_endian.py diff --git a/gguf-py/scripts/gguf_dump.py b/gguf-py/gguf/scripts/gguf_dump.py similarity index 100% rename from gguf-py/scripts/gguf_dump.py rename to gguf-py/gguf/scripts/gguf_dump.py diff --git a/gguf-py/scripts/gguf_hash.py b/gguf-py/gguf/scripts/gguf_hash.py similarity index 100% rename from gguf-py/scripts/gguf_hash.py rename to gguf-py/gguf/scripts/gguf_hash.py diff --git a/gguf-py/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py similarity index 100% rename from gguf-py/scripts/gguf_new_metadata.py rename to gguf-py/gguf/scripts/gguf_new_metadata.py diff --git a/gguf-py/scripts/gguf_set_metadata.py b/gguf-py/gguf/scripts/gguf_set_metadata.py similarity index 100% rename from gguf-py/scripts/gguf_set_metadata.py rename to gguf-py/gguf/scripts/gguf_set_metadata.py diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7009a11d4..7616c468a 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -55,7 +55,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -68,7 +68,7 @@ class TensorNameMap: MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox "transformer.ln_f", # gpt2 gpt-j falcon jais exaone - "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 + "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe "norm", # llama-pth "transformer.norm_f", # mpt dbrx "ln_f", # refact bloom qwen gpt2 @@ -108,7 +108,7 @@ class TensorNameMap: "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe + 
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -152,7 +152,7 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom "layers.{bid}.attention.wq", # llama-pth "encoder.layer.{bid}.attention.self.query", # bert @@ -165,7 +165,7 @@ class TensorNameMap: # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert @@ -179,7 +179,7 @@ class TensorNameMap: # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert "transformer.h.{bid}.attn.v_proj", # gpt-j @@ -197,7 +197,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert @@ -242,7 +242,7 @@ class TensorNameMap: "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone "h.{bid}.post_attention_layernorm", # bloom "transformer.blocks.{bid}.norm_2", # mpt - "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe + "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe "layers.{bid}.ffn_norm", # llama-pth "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "model.layers.{bid}.ln2", # yi @@ -265,7 +265,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_GATE_INP: ( "layers.{bid}.feed_forward.gate", # mixtral - "model.layers.{bid}.block_sparse_moe.gate", # mixtral + "model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe "model.layers.{bid}.mlp.gate", # qwen2moe olmoe "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx @@ -276,6 +276,10 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe ), + MODEL_TENSOR.FFN_EXP_PROBS_B: ( + "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 + ), + # Feed-forward up MODEL_TENSOR.FFN_UP: ( "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox @@ -306,10 +310,11 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_UP_EXP: ( - "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx - "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) + "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx + 
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) ), MODEL_TENSOR.FFN_UP_SHEXP: ( @@ -338,10 +343,11 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_GATE_EXP: ( - "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx - "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) + "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx + "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) ), MODEL_TENSOR.FFN_GATE_SHEXP: ( @@ -383,6 +389,7 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe + "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 9c3956256..92d7f22ec 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,12 +1,11 @@ [tool.poetry] name = "gguf" -version = "0.13.0" +version = "0.14.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ {include = "gguf"}, {include = "gguf/py.typed"}, - {include = "scripts"}, ] readme = "README.md" homepage = "https://ggml.ai" @@ -33,7 +32,7 @@ requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] -gguf-convert-endian = "scripts:gguf_convert_endian_entrypoint" -gguf-dump = "scripts:gguf_dump_entrypoint" -gguf-set-metadata = "scripts:gguf_set_metadata_entrypoint" -gguf-new-metadata = "scripts:gguf_new_metadata_entrypoint" +gguf-convert-endian = "gguf.scripts:gguf_convert_endian_entrypoint" +gguf-dump = "gguf.scripts:gguf_dump_entrypoint" +gguf-set-metadata = "gguf.scripts:gguf_set_metadata_entrypoint" +gguf-new-metadata = "gguf.scripts:gguf_new_metadata_entrypoint" diff --git a/include/llama-cpp.h b/include/llama-cpp.h index daa04d4d8..11306b17f 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -9,7 +9,7 @@ #include "llama.h" struct llama_model_deleter { - void operator()(llama_model * model) { llama_free_model(model); } + void operator()(llama_model * model) { llama_model_free(model); } }; struct llama_context_deleter { @@ -20,6 +20,11 @@ struct llama_sampler_deleter { void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); } }; +struct llama_lora_adapter_deleter { + void operator()(llama_lora_adapter * lora_adapter) { llama_lora_adapter_free(lora_adapter); } +}; + typedef std::unique_ptr llama_model_ptr; typedef std::unique_ptr llama_context_ptr; typedef std::unique_ptr llama_sampler_ptr; +typedef std::unique_ptr llama_lora_adapter_ptr; diff --git a/include/llama.h b/include/llama.h index a4abf395b..0295a51fb 100644 --- a/include/llama.h +++ b/include/llama.h @@ -34,7 +34,6 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF -// TODO: use everywhere in the implementation #define LLAMA_TOKEN_NULL -1 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' @@ -105,6 +104,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, }; enum 
llama_rope_type { @@ -385,6 +385,7 @@ extern "C" { } llama_chat_message; // lora adapter + // TODO: rename to llama_adapter_lora struct llama_lora_adapter; // Helpers for getting default parameters @@ -412,11 +413,19 @@ extern "C" { // Call once at the end of the program - currently only used for MPI LLAMA_API void llama_backend_free(void); - LLAMA_API struct llama_model * llama_load_model_from_file( + DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params), + "use llama_model_load_from_file instead"); + + LLAMA_API struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params); - LLAMA_API void llama_free_model(struct llama_model * model); + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), + "use llama_model_free instead"); + + LLAMA_API void llama_model_free(struct llama_model * model); // TODO: rename to llama_init_from_model LLAMA_API struct llama_context * llama_new_context_with_model( @@ -501,14 +510,19 @@ extern "C" { const char * fname_out, const llama_model_quantize_params * params); + // + // Adapters + // + // Load a LoRA adapter from file - // The loaded adapter will be associated to the given model, and will be free when the model is deleted + // TODO: rename to llama_adapter_lora_init LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init( struct llama_model * model, const char * path_lora); // Add a loaded LoRA adapter to given context // This will not modify model's weight + // TODO: rename to llama_set_adapter_lora LLAMA_API int32_t llama_lora_adapter_set( struct llama_context * ctx, struct llama_lora_adapter * adapter, @@ -516,16 +530,18 @@ extern "C" { // Remove a specific LoRA adapter from given context // Return -1 if the adapter is not present in the context + // TODO: rename to llama_rm_adapter_lora LLAMA_API int32_t llama_lora_adapter_remove( struct llama_context * ctx, struct llama_lora_adapter * adapter); // Remove all LoRA adapters from given context - LLAMA_API void llama_lora_adapter_clear( - struct llama_context * ctx); + // TODO: rename to llama_clear_adapter_lora + LLAMA_API void llama_lora_adapter_clear(struct llama_context * ctx); // Manually free a LoRA adapter // Note: loaded adapters will be free when the associated model is deleted + // TODO: rename to llama_adapter_lora_free LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter); // Apply a loaded control vector to a llama_context, or if data is NULL, clear @@ -534,6 +550,7 @@ extern "C" { // to an n_embd x n_layers buffer starting from layer 1. // il_start and il_end are the layer range the vector should apply to (both inclusive) // See llama_control_vector_load in common to load a control vector. + // TODO: rename to llama_adapter_cvec_apply LLAMA_API int32_t llama_control_vector_apply( struct llama_context * lctx, const float * data, @@ -546,6 +563,8 @@ extern "C" { // KV cache // + // TODO: remove llama_kv_cache_view_* API + // Information associated with an individual cell in the KV cache view. struct llama_kv_cache_view_cell { // The position for this cell. Takes KV cache shifts into account. @@ -592,8 +611,11 @@ extern "C" { LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); // Update the KV cache view structure with the current state of the KV cache. 
(use only for debugging purposes) + // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx) LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); + /// + // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); @@ -663,6 +685,9 @@ extern "C" { struct llama_context * ctx, llama_seq_id seq_id); + // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache + // how to avoid this? + // Defragment the KV cache // This will be applied: // - lazily on next llama_decode() diff --git a/media/llama-leader.jpeg b/media/llama-leader.jpeg deleted file mode 100644 index 0b4e6e1cf..000000000 Binary files a/media/llama-leader.jpeg and /dev/null differ diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index b4ac38bbf..a0921f1a9 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -e6d93f40dffe8733d5d72f1d8fa6b3ca27ae899f +c8bd0fee71dc8328d93be301bbee06bc10d30429 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2d3ea0994..aeb75bf3e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,9 +9,21 @@ llama_add_compile_flags() add_library(llama ../include/llama.h llama.cpp - llama-vocab.cpp + llama-adapter.cpp + llama-arch.cpp + llama-batch.cpp + llama-chat.cpp + llama-context.cpp llama-grammar.cpp + llama-hparams.cpp + llama-impl.cpp + llama-kv-cache.cpp + llama-mmap.cpp + llama-model-loader.cpp + llama-model.cpp + llama-quant.cpp llama-sampling.cpp + llama-vocab.cpp unicode.h unicode.cpp unicode-data.cpp diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp new file mode 100644 index 000000000..d4879b778 --- /dev/null +++ b/src/llama-adapter.cpp @@ -0,0 +1,346 @@ +#include "llama-adapter.h" + +#include "llama-model.h" + +#include +#include +#include +#include + +// vec + +struct ggml_tensor * llama_control_vector::tensor_for(int il) const { + if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { + return nullptr; + } + + return tensors[il]; +} + +struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { + ggml_tensor * layer_dir = tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx, cur, layer_dir); + } + + return cur; +} + +static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) { + const auto & hparams = model.hparams; + + GGML_ASSERT(cvec.tensors.empty()); + GGML_ASSERT(cvec.ctxs.empty()); + GGML_ASSERT(cvec.bufs.empty()); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + struct ggml_init_params params = { + /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + cvec.ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + // make tensors + cvec.tensors.reserve(hparams.n_layer); + cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0 + for (size_t il = 1; il < hparams.n_layer; il++) { 
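The ctx_for_buft lambda above implements a memoized one-context-per-buffer-type pattern: contexts are created lazily and reused so that all tensors destined for the same backend buffer type share one no_alloc context. A simplified stand-alone sketch with stand-in types (not the real ggml API):

    #include <map>
    #include <memory>

    struct buffer_type { int id; };               // stand-in for ggml_backend_buffer_type_t
    struct context     { /* tensor metadata */ }; // stand-in for ggml_context

    static std::map<int, std::unique_ptr<context>> ctx_map;

    static context * ctx_for_buft(buffer_type buft) {
        auto it = ctx_map.find(buft.id);
        if (it == ctx_map.end()) {
            auto ctx = std::make_unique<context>();
            context * raw = ctx.get();
            ctx_map.emplace(buft.id, std::move(ctx));
            return raw;
        }
        return it->second.get();
    }

    int main() {
        context * a = ctx_for_buft({0});
        context * b = ctx_for_buft({0});
        return a == b ? 0 : 1; // same buffer type -> same memoized context
    }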
+ ggml_backend_buffer_type_t buft = llama_model_select_buft(model, il); + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__); + return false; + } + ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + cvec.tensors.push_back(tensor); + } + + // allocate tensors / buffers and zero + cvec.bufs.reserve(ctx_map.size()); + for (auto it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx = it.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__); + return false; + } + ggml_backend_buffer_clear(buf, 0); + cvec.bufs.emplace_back(buf); + } + + return true; +} + +int32_t llama_control_vector_apply( + struct llama_control_vector & cvec, + const llama_model & model, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + const auto & hparams = model.hparams; + + if (data == nullptr) { + // disable the current control vector (but leave allocated for later) + cvec.layer_start = -1; + cvec.layer_end = -1; + return 0; + } + + if (n_embd != (int) hparams.n_embd) { + LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__); + return 1; + } + + if (cvec.tensors.empty()) { + if (!llama_control_vector_init(cvec, model)) { + return 1; + } + } + + cvec.layer_start = il_start; + cvec.layer_end = il_end; + + for (size_t il = 1; il < hparams.n_layer; il++) { + assert(cvec.tensors[il] != nullptr); + + const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present + if (off + n_embd <= len) { + ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il])); + } + } + + return 0; +} + +// lora + +llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) { + const std::string name(w->name); + + const auto pos = ab_map.find(name); + if (pos != ab_map.end()) { + return &pos->second; + } + + return nullptr; +} + +void llama_lora_adapter_free(struct llama_lora_adapter * adapter) { + delete adapter; +} + +static void llama_lora_adapter_init_impl(struct llama_model & model, const char * path_lora, struct llama_lora_adapter & adapter) { + LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); + + ggml_context * ctx_init; + struct gguf_init_params meta_gguf_params = { + /* .no_alloc = */ true, + /* .ctx = */ &ctx_init, + }; + + gguf_context_ptr ctx_gguf { gguf_init_from_file(path_lora, meta_gguf_params) }; + if (!ctx_gguf) { + throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora)); + } + + ggml_context_ptr ctx { ctx_init }; + + // check metadata + { + auto get_kv_str = [&](const std::string & key) -> std::string { + int id = gguf_find_key(ctx_gguf.get(), key.c_str()); + return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id)); + }; + auto get_kv_f32 = [&](const std::string & key) -> float { + int id = gguf_find_key(ctx_gguf.get(), key.c_str()); + return id < 0 ? 
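The indexing in llama_control_vector_apply above relies on a flat buffer holding one n_embd row per layer, starting at layer 1 (layer 0 never has a direction vector). A small illustration of that layout; helper names are illustrative only:

    #include <cstddef>
    #include <vector>

    // a control vector for n_layer layers stores (n_layer - 1) rows of n_embd floats
    std::vector<float> make_cvec_data(std::size_t n_layer, std::size_t n_embd) {
        return std::vector<float>((n_layer - 1) * n_embd, 0.0f);
    }

    std::size_t cvec_row_offset(std::size_t il, std::size_t n_embd) {
        return n_embd * (il - 1); // matches `off = n_embd * (il - 1)` in the loader
    }

    int main() {
        auto data = make_cvec_data(/*n_layer=*/32, /*n_embd=*/4096);
        return (cvec_row_offset(1, 4096) == 0 && data.size() == 31u * 4096u) ? 0 : 1;
    }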
0.0f : gguf_get_val_f32(ctx_gguf.get(), id); + }; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE)); + if (general_type != "adapter") { + throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type); + } + + auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE)); + auto general_arch = llm_arch_from_string(general_arch_str); + if (general_arch != model.arch) { + throw std::runtime_error("model arch and LoRA arch mismatch"); + } + + auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE)); + if (adapter_type != "lora") { + throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type); + } + + adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA)); + } + + int n_tensors = gguf_get_n_tensors(ctx_gguf.get()); + + // contexts for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + // add a new context + struct ggml_init_params params = { + /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * buft_ctx = ggml_init(params); + if (!buft_ctx) { + return nullptr; + } + ctx_map[buft] = buft_ctx; + adapter.ctxs.emplace_back(buft_ctx); + return buft_ctx; + }; + return it->second; + }; + + // bundle lora_a and lora_b into pairs + std::map ab_map; + auto str_endswith = [](const std::string & str, const std::string & suffix) { + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; + }; + + for (ggml_tensor * cur = ggml_get_first_tensor(ctx.get()); cur; cur = ggml_get_next_tensor(ctx.get(), cur)) { + std::string name(cur->name); + if (str_endswith(name, ".lora_a")) { + replace_all(name, ".lora_a", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_lora_weight(cur, nullptr); + } else { + ab_map[name].a = cur; + } + } else if (str_endswith(name, ".lora_b")) { + replace_all(name, ".lora_b", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_lora_weight(nullptr, cur); + } else { + ab_map[name].b = cur; + } + } else if (str_endswith(name, "_norm.weight")) { + // TODO: add support for norm vector + // for now, we don't really care because most adapters still work fine without it + continue; + } else { + throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); + } + } + + // add tensors + for (auto & it : ab_map) { + const std::string & name = it.first; + llama_lora_weight & w = it.second; + bool is_token_embd = str_endswith(name, "token_embd.weight"); + + if (!w.a || !w.b) { + throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); + } + + // device buft and device ctx + auto * model_tensor = llama_model_get_tensor(model, name.c_str()); + if (!model_tensor) { + throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); + } + + struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); + // validate tensor shape + if (is_token_embd) { + // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() + if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + } else { + if (model_tensor->ne[0] != 
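A stand-alone sketch of the suffix pairing performed by the loop above: adapter tensors named "<base>.lora_a" and "<base>.lora_b" are bundled under the base tensor name. The helper names here are hypothetical, not part of the library:

    #include <map>
    #include <string>
    #include <vector>

    struct lora_pair { std::string a, b; };

    // strip ".lora_a" / ".lora_b" and group both halves under the base tensor name
    std::map<std::string, lora_pair> pair_lora_tensors(const std::vector<std::string> & names) {
        std::map<std::string, lora_pair> out;
        for (const std::string & n : names) {
            const std::string suf_a = ".lora_a", suf_b = ".lora_b";
            if (n.size() > suf_a.size() && n.compare(n.size() - suf_a.size(), suf_a.size(), suf_a) == 0) {
                out[n.substr(0, n.size() - suf_a.size())].a = n;
            } else if (n.size() > suf_b.size() && n.compare(n.size() - suf_b.size(), suf_b.size(), suf_b) == 0) {
                out[n.substr(0, n.size() - suf_b.size())].b = n;
            }
        }
        return out;
    }

    int main() {
        auto m = pair_lora_tensors({"blk.0.attn_q.weight.lora_a", "blk.0.attn_q.weight.lora_b"});
        return m.count("blk.0.attn_q.weight") == 1 ? 0 : 1;
    }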
w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + if (w.a->ne[1] != w.b->ne[0]) { + throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + } + } + + // save tensor to adapter + struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); + struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); + ggml_set_name(tensor_a, w.a->name); + ggml_set_name(tensor_b, w.b->name); + adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b); + } + + // allocate tensors / buffers and zero + { + adapter.ctxs.reserve(ctx_map.size()); + adapter.bufs.reserve(ctx_map.size()); + for (auto & it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx_dev = it.second; + ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft) }; + if (!buf) { + throw std::runtime_error("failed to allocate buffer for lora adapter\n"); + } + LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0); + adapter.bufs.emplace_back(std::move(buf)); + } + } + + // set tensor data + { + llama_file gguf_file(path_lora, "rb"); + std::vector read_buf; + auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) { + size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name)); + size_t size = ggml_nbytes(orig); + read_buf.resize(size); + gguf_file.seek(offs, SEEK_SET); + gguf_file.read_raw(read_buf.data(), size); + ggml_backend_tensor_set(dev, read_buf.data(), 0, size); + }; + for (auto & it : adapter.ab_map) { + auto orig = ab_map[it.first]; + auto dev = it.second; + set_tensor(orig.a, dev.a); + set_tensor(orig.b, dev.b); + } + } + + LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); +} + +struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) { + struct llama_lora_adapter * adapter = new llama_lora_adapter(); + + try { + llama_lora_adapter_init_impl(*model, path_lora, *adapter); + return adapter; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); + + delete adapter; + } + + return nullptr; +} diff --git a/src/llama-adapter.h b/src/llama-adapter.h new file mode 100644 index 000000000..3448656b1 --- /dev/null +++ b/src/llama-adapter.h @@ -0,0 +1,73 @@ +#pragma once + +#include "llama-impl.h" +#include "llama-hparams.h" + +#include "ggml-cpp.h" + +#include +#include + +// +// llama_adapter_cvec +// + +// TODO: rename to llama_adapter_cvec +struct llama_control_vector { + std::vector ctxs; + std::vector bufs; + + std::vector tensors; // per layer + + int32_t layer_start = -1; + int32_t layer_end = -1; + + struct ggml_tensor * tensor_for(int il) const; + + struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; +}; + +int32_t llama_control_vector_apply( + struct llama_control_vector & cvec, + const llama_model & model, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end); + +// +// llama_adapter_lora +// + +// TODO: rename to llama_adapter_lora_weight +struct llama_lora_weight { + struct ggml_tensor * a = nullptr; + struct ggml_tensor * b = nullptr; + + // get actual 
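The shape checks just above encode the usual low-rank factorization constraint; restated on ggml's ne[] convention as a rough self-contained check (example dimensions are arbitrary):

    #include <array>
    #include <cstdint>

    int main() {
        const int64_t n_in = 4096, n_out = 4096, r = 16;

        std::array<int64_t, 2> W{n_in, n_out}; // base weight, ne = [n_in, n_out]
        std::array<int64_t, 2> A{n_in, r};     // lora_a,      ne = [n_in, r]
        std::array<int64_t, 2> B{r, n_out};    // lora_b,      ne = [r, n_out]

        const bool ok = W[0] == A[0] && W[1] == B[1] // the model_tensor->ne checks above
                     && A[1] == B[0];                // the "lora_a tensor is not transposed" check
        return ok ? 0 : 1;
    }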
scale based on rank and alpha + float get_scale(float alpha, float adapter_scale) { + const float rank = (float) b->ne[0]; + const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale; + return scale; + } + + llama_lora_weight() = default; + llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} +}; + +// TODO: rename to llama_adapter_lora +struct llama_lora_adapter { + // map tensor name to lora_a_b + std::unordered_map ab_map; + + std::vector ctxs; + std::vector bufs; + + float alpha; + + llama_lora_adapter() = default; + ~llama_lora_adapter() = default; + + llama_lora_weight * get_weight(struct ggml_tensor * w); +}; diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp new file mode 100644 index 000000000..eef66ed31 --- /dev/null +++ b/src/llama-arch.cpp @@ -0,0 +1,1456 @@ +#include "llama-arch.h" + +#include "llama-impl.h" + +#include + +static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_DECI, "deci" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_COHERE2, "cohere2" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMO2, "olmo2" }, + { LLM_ARCH_OLMOE, "olmoe" }, + { LLM_ARCH_OPENELM, "openelm" }, + { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK, "deepseek" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_T5ENCODER, "t5encoder" }, + { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_GRANITE, "granite" }, + { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, +}; + +static const std::map LLM_KV_NAMES = { + { LLM_KV_GENERAL_TYPE, "general.type" }, + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_VERSION, "general.version" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, 
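A worked example of the get_scale rule above, where the rank is read from lora_b's first dimension: with alpha = 32, rank = 16 and a user scale of 0.5, the effective scale is 0.5 * 32 / 16 = 1.0; with alpha = 0 the user scale is used unchanged. As a sketch:

    #include <cassert>

    // mirrors llama_lora_weight::get_scale(): adapter_scale * alpha / rank,
    // or just adapter_scale when the adapter stores no alpha
    static float lora_scale(float alpha, float adapter_scale, float rank) {
        return alpha ? adapter_scale * alpha / rank : adapter_scale;
    }

    int main() {
        assert(lora_scale(32.0f, 0.5f, 16.0f) == 1.0f);
        assert(lora_scale( 0.0f, 0.5f, 16.0f) == 0.5f);
        return 0;
    }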
"general.source.huggingface.repository" }, + + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, + { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, + { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, + { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, + { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_SWIN_NORM, "%s.swin_norm" }, + { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, + { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" }, + { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, + { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, + { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, + + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, + { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + + { LLM_KV_SPLIT_NO, "split.no" }, + { LLM_KV_SPLIT_COUNT, "split.count" }, + { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" }, + + { 
LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, + { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, + { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, + { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, + + { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" }, + { LLM_KV_POSNET_BLOCK_COUNT, "%s.posnet.block_count" }, + + { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" }, + { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + + // deprecated + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, +}; + +static const std::map> LLM_TENSOR_NAMES = { + { + LLM_ARCH_LLAMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { 
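The "%s" in the metadata keys above is filled with the architecture name, and the per-block tensor names around it use "blk.%d" with the block index. A minimal sketch of that substitution; the helpers are illustrative, not the actual LLM_KV/LLM_TN helpers:

    #include <cstdio>
    #include <string>

    static std::string format_name(const char * fmt, const char * arch) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), fmt, arch);
        return buf;
    }

    static std::string format_blk(const char * fmt, int bid) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), fmt, bid);
        return buf;
    }

    int main() {
        // "%s.embedding_length" + "llama" -> "llama.embedding_length"
        std::printf("%s\n", format_name("%s.embedding_length", "llama").c_str());
        // "blk.%d.attn_q" + 3 -> "blk.3.attn_q"
        std::printf("%s\n", format_blk("blk.%d.attn_q", 3).c_str());
        return 0;
    }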
LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_DECI, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_BAICHUAN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_FALCON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GROK, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + }, + }, + { + LLM_ARCH_GPT2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + 
{ LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_GPTJ, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_GPTNEOX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MPT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output"}, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"}, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"}, + }, + }, + { + LLM_ARCH_STARCODER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_REFACT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, + }, + }, + { + LLM_ARCH_NOMIC_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { 
LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_JINA_BERT_V2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + }, + }, + { + LLM_ARCH_BLOOM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_STABLELM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_QWEN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN2VL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { 
LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN2MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { + LLM_ARCH_PHI2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_PHI3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_PHIMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_PLAMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, 
"blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_CODESHELL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_ORION, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_INTERNLM2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MINICPM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + }, + }, + { + LLM_ARCH_MINICPM3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + 
{ LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_GEMMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GEMMA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, + { + LLM_ARCH_STARCODER2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, + { + LLM_ARCH_XVERSE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { 
+ LLM_ARCH_COMMAND_R, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_COHERE2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_DBRX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_OLMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_OLMO2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_OLMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_OPENELM, + { + { 
LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_ARCTIC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_DEEPSEEK, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { + LLM_ARCH_DEEPSEEK2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { 
LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, + { + LLM_ARCH_CHATGLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_BITNET, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, + { + LLM_ARCH_T5, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" }, + { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" }, + { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" }, + { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" }, + { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" }, + { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" }, + { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" }, + { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" }, + { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" }, + { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" }, + { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" }, + { LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" }, + { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" }, + { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" }, + { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" }, + { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" }, + { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_T5ENCODER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { 
LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_JAIS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_NEMOTRON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_EXAONE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_RWKV6, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" }, + { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" }, + { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, + { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, + { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" }, + { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, + { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, + { 
LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, + }, + }, + { + LLM_ARCH_GRANITE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GRANITE_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_CHAMELEON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_WAVTOKENIZER_DEC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_CONV1D, "conv1d" }, + { LLM_TENSOR_CONVNEXT_DW, "convnext.%d.dw" }, + { LLM_TENSOR_CONVNEXT_NORM, "convnext.%d.norm" }, + { LLM_TENSOR_CONVNEXT_PW1, "convnext.%d.pw1" }, + { LLM_TENSOR_CONVNEXT_PW2, "convnext.%d.pw2" }, + { LLM_TENSOR_CONVNEXT_GAMMA, "convnext.%d.gamma" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_POS_NET_CONV1, "posnet.%d.conv1" }, + { LLM_TENSOR_POS_NET_CONV2, "posnet.%d.conv2" }, + { LLM_TENSOR_POS_NET_NORM, "posnet.%d.norm" }, + { LLM_TENSOR_POS_NET_NORM1, "posnet.%d.norm1" }, + { LLM_TENSOR_POS_NET_NORM2, "posnet.%d.norm2" }, + { LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" }, + { LLM_TENSOR_POS_NET_ATTN_Q, "posnet.%d.attn_q" }, + { LLM_TENSOR_POS_NET_ATTN_K, "posnet.%d.attn_k" }, + { LLM_TENSOR_POS_NET_ATTN_V, "posnet.%d.attn_v" }, + { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, + }, + }, + { + LLM_ARCH_UNKNOWN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, +}; + +static const std::map LLM_TENSOR_INFOS = { + {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, 
GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, + {LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, + {LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, + {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, 
GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_INP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_INP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_IN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_OUTPUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CHANNEL_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CHANNEL_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}}, + {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, + {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, + {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, + {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + 
{LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_CROSS_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_ENC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + // this tensor is loaded for T5, but never used + {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, + {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_CONV1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_CONV2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_DW, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, + {LLM_TENSOR_CONVNEXT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, +}; + +LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {} + +std::string LLM_KV::operator()(llm_kv kv) const { + return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); +} + +std::string LLM_TN_IMPL::str() const { + if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + return "__missing__"; + } + + std::string name = ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid); + + if (suffix != nullptr) { + name += "."; + name += suffix; + } + + return name; +} + +const char * llm_arch_name(llm_arch arch) { + auto it = LLM_ARCH_NAMES.find(arch); + if (it == LLM_ARCH_NAMES.end()) { + return "unknown"; + } + return it->second; +} + +llm_arch llm_arch_from_string(const std::string & name) { + for 
(const auto & kv : LLM_ARCH_NAMES) { // NOLINT + if (kv.second == name) { + return kv.first; + } + } + + return LLM_ARCH_UNKNOWN; +} + +const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { + return LLM_TENSOR_INFOS.at(tensor); +} diff --git a/src/llama-arch.h b/src/llama-arch.h new file mode 100644 index 000000000..2e5f97b77 --- /dev/null +++ b/src/llama-arch.h @@ -0,0 +1,396 @@ +#pragma once + +#include "ggml.h" // ggml_op + +#include + +// +// gguf constants (sync with gguf.py) +// + +enum llm_arch { + LLM_ARCH_LLAMA, + LLM_ARCH_DECI, + LLM_ARCH_FALCON, + LLM_ARCH_BAICHUAN, + LLM_ARCH_GROK, + LLM_ARCH_GPT2, + LLM_ARCH_GPTJ, + LLM_ARCH_GPTNEOX, + LLM_ARCH_MPT, + LLM_ARCH_STARCODER, + LLM_ARCH_REFACT, + LLM_ARCH_BERT, + LLM_ARCH_NOMIC_BERT, + LLM_ARCH_JINA_BERT_V2, + LLM_ARCH_BLOOM, + LLM_ARCH_STABLELM, + LLM_ARCH_QWEN, + LLM_ARCH_QWEN2, + LLM_ARCH_QWEN2MOE, + LLM_ARCH_QWEN2VL, + LLM_ARCH_PHI2, + LLM_ARCH_PHI3, + LLM_ARCH_PHIMOE, + LLM_ARCH_PLAMO, + LLM_ARCH_CODESHELL, + LLM_ARCH_ORION, + LLM_ARCH_INTERNLM2, + LLM_ARCH_MINICPM, + LLM_ARCH_MINICPM3, + LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, + LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, + LLM_ARCH_XVERSE, + LLM_ARCH_COMMAND_R, + LLM_ARCH_COHERE2, + LLM_ARCH_DBRX, + LLM_ARCH_OLMO, + LLM_ARCH_OLMO2, + LLM_ARCH_OLMOE, + LLM_ARCH_OPENELM, + LLM_ARCH_ARCTIC, + LLM_ARCH_DEEPSEEK, + LLM_ARCH_DEEPSEEK2, + LLM_ARCH_CHATGLM, + LLM_ARCH_BITNET, + LLM_ARCH_T5, + LLM_ARCH_T5ENCODER, + LLM_ARCH_JAIS, + LLM_ARCH_NEMOTRON, + LLM_ARCH_EXAONE, + LLM_ARCH_RWKV6, + LLM_ARCH_GRANITE, + LLM_ARCH_GRANITE_MOE, + LLM_ARCH_CHAMELEON, + LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_UNKNOWN, +}; + +enum llm_kv { + LLM_KV_GENERAL_TYPE, + LLM_KV_GENERAL_ARCHITECTURE, + LLM_KV_GENERAL_QUANTIZATION_VERSION, + LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_NAME, + LLM_KV_GENERAL_AUTHOR, + LLM_KV_GENERAL_VERSION, + LLM_KV_GENERAL_URL, + LLM_KV_GENERAL_DESCRIPTION, + LLM_KV_GENERAL_LICENSE, + LLM_KV_GENERAL_SOURCE_URL, + LLM_KV_GENERAL_SOURCE_HF_REPO, + + LLM_KV_VOCAB_SIZE, + LLM_KV_CONTEXT_LENGTH, + LLM_KV_EMBEDDING_LENGTH, + LLM_KV_FEATURES_LENGTH, + LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, + LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, + LLM_KV_USE_PARALLEL_RESIDUAL, + LLM_KV_TENSOR_DATA_LAYOUT, + LLM_KV_EXPERT_COUNT, + LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_WEIGHTS_SCALE, + LLM_KV_EXPERT_WEIGHTS_NORM, + LLM_KV_EXPERT_GATING_FUNC, + LLM_KV_POOLING_TYPE, + LLM_KV_LOGIT_SCALE, + LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, + LLM_KV_RESCALE_EVERY_N_LAYERS, + LLM_KV_TIME_MIX_EXTRA_DIM, + LLM_KV_TIME_DECAY_EXTRA_DIM, + LLM_KV_RESIDUAL_SCALE, + LLM_KV_EMBEDDING_SCALE, + + LLM_KV_ATTENTION_HEAD_COUNT, + LLM_KV_ATTENTION_HEAD_COUNT_KV, + LLM_KV_ATTENTION_MAX_ALIBI_BIAS, + LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_KEY_LENGTH, + LLM_KV_ATTENTION_VALUE_LENGTH, + LLM_KV_ATTENTION_LAYERNORM_EPS, + LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + LLM_KV_ATTENTION_GROUPNORM_EPS, + LLM_KV_ATTENTION_GROUPNORM_GROUPS, + LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, + LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, + LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SCALE, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_DIMENSION_SECTIONS, + LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_SCALING_TYPE, + LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ATTN_FACTOR, + 
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, + LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_ROPE_SCALING_YARN_LOG_MUL, + + LLM_KV_SPLIT_NO, + LLM_KV_SPLIT_COUNT, + LLM_KV_SPLIT_TENSORS_COUNT, + + LLM_KV_SSM_INNER_SIZE, + LLM_KV_SSM_CONV_KERNEL, + LLM_KV_SSM_STATE_SIZE, + LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_DT_B_C_RMS, + + LLM_KV_WKV_HEAD_SIZE, + + LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_PRE, + LLM_KV_TOKENIZER_LIST, + LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, + LLM_KV_TOKENIZER_SCORES, + LLM_KV_TOKENIZER_MERGES, + LLM_KV_TOKENIZER_BOS_ID, + LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_EOT_ID, + LLM_KV_TOKENIZER_EOM_ID, + LLM_KV_TOKENIZER_UNK_ID, + LLM_KV_TOKENIZER_SEP_ID, + LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_CLS_ID, + LLM_KV_TOKENIZER_MASK_ID, + LLM_KV_TOKENIZER_ADD_BOS, + LLM_KV_TOKENIZER_ADD_EOS, + LLM_KV_TOKENIZER_ADD_PREFIX, + LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, + LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, + LLM_KV_TOKENIZER_HF_JSON, + LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_FIM_PRE_ID, + LLM_KV_TOKENIZER_FIM_SUF_ID, + LLM_KV_TOKENIZER_FIM_MID_ID, + LLM_KV_TOKENIZER_FIM_PAD_ID, + LLM_KV_TOKENIZER_FIM_REP_ID, + LLM_KV_TOKENIZER_FIM_SEP_ID, + + LLM_KV_ADAPTER_TYPE, + LLM_KV_ADAPTER_LORA_ALPHA, + + LLM_KV_POSNET_EMBEDDING_LENGTH, + LLM_KV_POSNET_BLOCK_COUNT, + + LLM_KV_CONVNEXT_EMBEDDING_LENGTH, + LLM_KV_CONVNEXT_BLOCK_COUNT, + + // deprecated: + LLM_KV_TOKENIZER_PREFIX_ID, + LLM_KV_TOKENIZER_SUFFIX_ID, + LLM_KV_TOKENIZER_MIDDLE_ID, +}; + +enum llm_tensor { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_ACT, + LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_UP_EXP, + LLM_TENSOR_FFN_NORM_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, // merged experts + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_X, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_OUT, + LLM_TENSOR_TIME_MIX_W1, + LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_LERP_X, + LLM_TENSOR_TIME_MIX_LERP_W, + LLM_TENSOR_TIME_MIX_LERP_K, + LLM_TENSOR_TIME_MIX_LERP_V, + LLM_TENSOR_TIME_MIX_LERP_R, + LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_FIRST, + LLM_TENSOR_TIME_MIX_DECAY, + LLM_TENSOR_TIME_MIX_DECAY_W1, + LLM_TENSOR_TIME_MIX_DECAY_W2, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_GATE, + LLM_TENSOR_TIME_MIX_LN, + LLM_TENSOR_TIME_MIX_OUTPUT, + LLM_TENSOR_CHANNEL_MIX_LERP_K, + LLM_TENSOR_CHANNEL_MIX_LERP_R, + LLM_TENSOR_CHANNEL_MIX_KEY, + LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, + LLM_TENSOR_CHANNEL_MIX_VALUE, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + 
LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, + LLM_TENSOR_DEC_ATTN_NORM, + LLM_TENSOR_DEC_ATTN_Q, + LLM_TENSOR_DEC_ATTN_K, + LLM_TENSOR_DEC_ATTN_V, + LLM_TENSOR_DEC_ATTN_OUT, + LLM_TENSOR_DEC_ATTN_REL_B, + LLM_TENSOR_DEC_CROSS_ATTN_NORM, + LLM_TENSOR_DEC_CROSS_ATTN_Q, + LLM_TENSOR_DEC_CROSS_ATTN_K, + LLM_TENSOR_DEC_CROSS_ATTN_V, + LLM_TENSOR_DEC_CROSS_ATTN_OUT, + LLM_TENSOR_DEC_CROSS_ATTN_REL_B, + LLM_TENSOR_DEC_FFN_NORM, + LLM_TENSOR_DEC_FFN_GATE, + LLM_TENSOR_DEC_FFN_DOWN, + LLM_TENSOR_DEC_FFN_UP, + LLM_TENSOR_DEC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + LLM_TENSOR_ENC_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CONV1D, + LLM_TENSOR_CONVNEXT_DW, + LLM_TENSOR_CONVNEXT_NORM, + LLM_TENSOR_CONVNEXT_PW1, + LLM_TENSOR_CONVNEXT_PW2, + LLM_TENSOR_CONVNEXT_GAMMA, + LLM_TENSOR_POS_NET_CONV1, + LLM_TENSOR_POS_NET_CONV2, + LLM_TENSOR_POS_NET_NORM, + LLM_TENSOR_POS_NET_NORM1, + LLM_TENSOR_POS_NET_NORM2, + LLM_TENSOR_POS_NET_ATTN_NORM, + LLM_TENSOR_POS_NET_ATTN_Q, + LLM_TENSOR_POS_NET_ATTN_K, + LLM_TENSOR_POS_NET_ATTN_V, + LLM_TENSOR_POS_NET_ATTN_OUT, +}; + +enum llm_tensor_layer { + LLM_TENSOR_LAYER_INPUT, + LLM_TENSOR_LAYER_REPEATING, + LLM_TENSOR_LAYER_OUTPUT, +}; + +struct LLM_KV { + LLM_KV(llm_arch arch); + + llm_arch arch; + + std::string operator()(llm_kv kv) const; +}; + +// helper to handle gguf constants +// usage: +// +// const auto tn = LLM_TN(LLM_ARCH_LLAMA); +// +// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" +// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" +// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" +// +struct LLM_TN_IMPL { + const llm_arch arch; + const llm_tensor tensor; + const char * const suffix; + const int bid; + const int xid; + + std::string str() const; + + operator std::string() const { + return str(); + } + + friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) { + return str == tn.str(); + } + + friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) { + return str != tn.str(); + } +}; + +struct LLM_TN { + LLM_TN(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const { + return { arch, tensor, suffix, bid, xid }; + } + + LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const { + return { arch, tensor, nullptr, bid, xid }; + } +}; + + +struct llm_tensor_info { + llm_tensor_layer layer; + ggml_op op; +}; + +const char * llm_arch_name(llm_arch arch); + +llm_arch llm_arch_from_string(const std::string & name); + +const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp new file mode 100644 index 000000000..01d5ca57f --- /dev/null +++ b/src/llama-batch.cpp @@ -0,0 +1,368 @@ +#include "llama-batch.h" + +#include +#include + +llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { + // clear empty sequences + // the previous ubatch is assumed to be gone, + // so nothing should refer to values in these sequences anymore. 
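+    // note: sequences are consumed from the back of `seq`, so fully-drained
+    // entries (length == 0) accumulate at the end and can simply be popped here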
+ for (size_t i = seq.size(); i-- > 0;) { + if (seq[i].length == 0) { + seq.pop_back(); + } else { + break; + } + } + ubatch_token.resize(!has_embd ? n_ubatch : 0); + ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); + ubatch_pos.resize(n_ubatch); + ubatch_n_seq_id.resize(n_ubatch); + ubatch_seq_id.resize(n_ubatch); + ubatch_output.resize(n_ubatch); + llama_ubatch ubatch = { + /*equal_seqs =*/ true, + /*n_tokens =*/ 0, + /*n_seq_tokens =*/ 0, + /*n_seqs =*/ 0, + /*token =*/ !has_embd ? ubatch_token.data() : nullptr, + /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, + /*pos =*/ ubatch_pos.data(), + /*n_seq_id =*/ ubatch_n_seq_id.data(), + /*seq_id =*/ ubatch_seq_id.data(), + /*output =*/ ubatch_output.data(), + }; + return ubatch; +} + +void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { + GGML_ASSERT(batch != nullptr); + GGML_ASSERT(length <= seq.length); + // Can only add sequences of equal lengths to a batch, + // otherwise it isn't clear to which sequence a token belongs + GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); + GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); + // NOTE: loops are separated for cache-friendliness + if (batch->token) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.token = batch->token + seq.offset; + } + } else { + ubatch.token = nullptr; + } + if (batch->embd) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + memcpy( + ubatch.embd + (n_embd * (ubatch.n_tokens + i)), + batch->embd + (n_embd * ids[seq.offset + i]), + n_embd * sizeof(float) + ); + } + } else { + // simple split + ubatch.embd = batch->embd + (n_embd * seq.offset); + } + } else { + ubatch.embd = nullptr; + } + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; + } + if (ubatch.equal_seqs) { + ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; + if (seq.seq_id) { + ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; + } + } else { + // simple split + if (batch->n_seq_id) { + ubatch.n_seq_id = batch->n_seq_id + seq.offset; + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + } + } + if (batch->seq_id) { + ubatch.seq_id = batch->seq_id + seq.offset; + } + } + if (logits_all) { + for (size_t i = 0; i < length; ++i) { + ubatch.output[ubatch.n_tokens + i] = 1; + out_ids.push_back(ids[seq.offset + i]); + } + } else if (batch->logits) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { out_ids.push_back(id); } + } + } else { + // simple split + ubatch.output = batch->logits + seq.offset; + for (size_t i = 0; i < length; ++i) { + if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } + } + } + } else { + // only get last output + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_last = id == ids.size() - 1; + ubatch.output[ubatch.n_tokens + i] = is_last; + if (is_last) { out_ids.push_back(id); } + } + } + if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { + ubatch.n_seq_tokens = ubatch.equal_seqs ? 
length : 1; + } + ubatch.n_tokens += length; + ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits + seq.offset += length; + seq.length -= length; + n_tokens -= length; + GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); +} + +llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + ubatch.equal_seqs = false; + if (!seq.empty()) { + llama_sbatch_seq & s = seq[0]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; +} + +llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + size_t length = 0; + size_t n_tokens_in_ubatch = 0; + GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits + // smallest first, because it's easier to split this way; + // starting from the end to pop in constant time. + for (size_t i = seq.size(); i-- > 0;) { + llama_sbatch_seq & s = seq[i]; + GGML_ASSERT(s.length > 0); + if (length == 0) { + length = s.length < n_ubatch ? s.length : n_ubatch; + } + add_seq_to_ubatch(ubatch, s, length); + n_tokens_in_ubatch += length; + // shared prompts can't be mixed with any of their sequences, + // so it's safer to compute them in their own ubatch + if (s.n_seq_id > 1) { break; } + // stop when there isn't enough space for another sequence + if (length + n_tokens_in_ubatch > n_ubatch) { break; } + } + } + return ubatch; +} + +llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + llama_sbatch_seq & s = seq[seq.size() - 1]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; +} + +void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { + GGML_ASSERT(batch.n_tokens >= 0); + this->batch = &batch; + this->n_embd = n_embd; + this->logits_all = logits_all; + + n_tokens = batch.n_tokens; + ids.resize(n_tokens); + out_ids.clear(); + // TODO: reserve out_ids and seq + + for (size_t i = 0; i < n_tokens; ++i) { + ids[i] = i; + } + if (simple_split) { + seq.resize(1); + llama_sbatch_seq & s = seq[0]; + s.n_seq_id = 0; + s.seq_id = nullptr; + s.offset = 0; + s.length = n_tokens; + return; + } + std::sort(ids.begin(), ids.end(), + [&batch](size_t a, size_t b) { + int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; + int32_t n_seq_b = batch.n_seq_id ? 
batch.n_seq_id[b] : 1; + // sort by seq_id, then by pos + if (n_seq_a == n_seq_b) { + if (batch.seq_id) { + for (int32_t i = 0; i < n_seq_a; ++i) { + llama_seq_id seq_id_a = batch.seq_id[a][i]; + llama_seq_id seq_id_b = batch.seq_id[b][i]; + // smaller seq_ids go first + if (seq_id_a != seq_id_b) { + return seq_id_a < seq_id_b; + } + } + } + // when all else is equal, sort by pos + if (batch.pos) { + return batch.pos[a] < batch.pos[b]; + } + // no pos, sort by id + return a < b; + } + // shared prompts go first + return n_seq_a > n_seq_b; + } + ); + // init seq + llama_sbatch_seq * last_seq = nullptr; + + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; + } + } + if (same) { + last_seq->length += 1; + continue; + } + } + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1}; + seq.push_back(new_seq); + last_seq = &seq.back(); + } + // keep shared prompts first at the end, then sort by length descending. + std::sort(seq.begin(), seq.end(), + [](llama_sbatch_seq & a, llama_sbatch_seq & b) { + if (a.n_seq_id == b.n_seq_id) { + return a.length > b.length; + } + return a.n_seq_id < b.n_seq_id; + } + ); +} + +llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) { + batch = in_batch; + GGML_ASSERT(batch.n_tokens > 0); + if (!batch.pos) { + pos.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + pos[i] = i + p0; + } + batch.pos = pos.data(); + } + if (!batch.n_seq_id) { + n_seq_id.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + n_seq_id[i] = seq_id_0.size(); + } + batch.n_seq_id = n_seq_id.data(); + } + if (!batch.seq_id) { + seq_id.resize(batch.n_tokens + 1); + seq_id[batch.n_tokens] = NULL; + for (int32_t i = 0; i < batch.n_tokens; i++) { + seq_id[i] = seq_id_0.data(); + } + batch.seq_id = seq_id.data(); + } + if (!batch.logits) { + logits.resize(batch.n_tokens); + logits[logits.size() - 1] = true; + batch.logits = logits.data(); + } +} + +// +// interface implementation +// + +struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens) { + return { + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + }; +} + +struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch batch = { + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + }; + + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + } + + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); + for (int i = 0; i < n_tokens_alloc; ++i) { + batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch.seq_id[n_tokens_alloc] = nullptr; + + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + + return batch; +} + +void llama_batch_free(struct llama_batch 
batch) { + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i] != nullptr; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} diff --git a/src/llama-batch.h b/src/llama-batch.h new file mode 100644 index 000000000..773c3808b --- /dev/null +++ b/src/llama-batch.h @@ -0,0 +1,88 @@ +#pragma once + +#include "llama.h" + +#include +#include + +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + // TODO: whole_seqs for embeddings? + + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_seq_tokens; // tokens per sequence + uint32_t n_seqs; + + llama_token * token; // [n_tokens] + float * embd; // [n_embd, n_tokens] + llama_pos * pos; // [n_tokens] + int32_t * n_seq_id; // [n_seqs] + llama_seq_id ** seq_id; // [n_seqs] + int8_t * output; // [n_tokens] +}; + +struct llama_sbatch_seq { + int32_t n_seq_id; + + llama_seq_id * seq_id; + + size_t offset; + size_t length; +}; + +// sequence-length-aware batch splitting +struct llama_sbatch { + // tokens left in this batch + size_t n_tokens; + + size_t n_embd; + + bool logits_all; // TODO: remove once lctx.logits_all is removed too + + // sorted indices into the batch + std::vector ids; + // batch indices of the output + std::vector out_ids; + std::vector seq; + + const llama_batch * batch = nullptr; + + // buffers for the ubatch + std::vector ubatch_token; + std::vector ubatch_embd; + std::vector ubatch_pos; + std::vector ubatch_n_seq_id; + std::vector ubatch_seq_id; + std::vector ubatch_output; + + llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); + + void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length); + + // simple split, unknown number of sequences of unequal lengths + llama_ubatch split_simple(size_t n_ubatch); + + // make batches of equal-length sequences + llama_ubatch split_equal(size_t n_ubatch); + + // sequence-wise split + llama_ubatch split_seq(size_t n_ubatch); + + void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); +}; + +// temporary allocate memory for the input batch if needed +struct llama_batch_allocr { + struct llama_batch batch; + + std::array seq_id_0 = { 0 }; // default sequence id + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector logits; + + // optionally fulfill the batch returned by llama_batch_get_one + llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); +}; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp new file mode 100644 index 000000000..1347ec156 --- /dev/null +++ b/src/llama-chat.cpp @@ -0,0 +1,578 @@ +#include "llama-chat.h" + +#include "llama.h" + +#include +#include + +#if __cplusplus >= 202000L + #define LU8(x) (const char*)(u8##x) +#else + #define LU8(x) u8##x +#endif + +// trim whitespace from the beginning and end of a string +static std::string trim(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && isspace(str[start])) { + start += 1; + } + while (end > start && isspace(str[end - 1])) { + end -= 1; + } + return str.substr(start, end - start); +} + +static const std::map LLM_CHAT_TEMPLATES = { + { "chatml", LLM_CHAT_TEMPLATE_CHATML }, + { "llama2", LLM_CHAT_TEMPLATE_LLAMA_2 }, + { "llama2-sys", LLM_CHAT_TEMPLATE_LLAMA_2_SYS 
}, + { "llama2-sys-bos", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS }, + { "llama2-sys-strip", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP }, + { "mistral-v1", LLM_CHAT_TEMPLATE_MISTRAL_V1 }, + { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, + { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, + { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, + { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, + { "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, + { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, + { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, + { "monarch", LLM_CHAT_TEMPLATE_MONARCH }, + { "gemma", LLM_CHAT_TEMPLATE_GEMMA }, + { "orion", LLM_CHAT_TEMPLATE_ORION }, + { "openchat", LLM_CHAT_TEMPLATE_OPENCHAT }, + { "vicuna", LLM_CHAT_TEMPLATE_VICUNA }, + { "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA }, + { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, + { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, + { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, + { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, + { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, + { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, + { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, + { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, + { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE }, + { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, + { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, +}; + +llm_chat_template llm_chat_template_from_str(const std::string & name) { + return LLM_CHAT_TEMPLATES.at(name); +} + +llm_chat_template llm_chat_detect_template(const std::string & tmpl) { + try { + return llm_chat_template_from_str(tmpl); + } catch (const std::out_of_range &) { + // ignore + } + + auto tmpl_contains = [&tmpl](const char * haystack) -> bool { + return tmpl.find(haystack) != std::string::npos; + }; + if (tmpl_contains("<|im_start|>")) { + return tmpl_contains("<|im_sep|>") + ? 
LLM_CHAT_TEMPLATE_PHI_4 + : LLM_CHAT_TEMPLATE_CHATML; + } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { + if (tmpl_contains("[SYSTEM_PROMPT]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V7; + } else if ( + // catches official 'v1' template + tmpl_contains("' [INST] ' + system_message") + // catches official 'v3' and 'v3-tekken' templates + || tmpl_contains("[AVAILABLE_TOOLS]") + ) { + // Official mistral 'v1', 'v3' and 'v3-tekken' templates + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + if (tmpl_contains(" [INST]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V1; + } else if (tmpl_contains("\"[INST]\"")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN; + } + return LLM_CHAT_TEMPLATE_MISTRAL_V3; + } else { + // llama2 template and its variants + // [variant] support system message + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl_contains("<>"); + bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); + bool strip_message = tmpl_contains("content.strip()"); + if (strip_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; + } else if (add_bos_inside_history) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; + } else if (support_system_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS; + } else { + return LLM_CHAT_TEMPLATE_LLAMA_2; + } + } + } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { + return LLM_CHAT_TEMPLATE_PHI_3; + } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { + return LLM_CHAT_TEMPLATE_FALCON_3; + } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { + return LLM_CHAT_TEMPLATE_ZEPHYR; + } else if (tmpl_contains("bos_token + message['role']")) { + return LLM_CHAT_TEMPLATE_MONARCH; + } else if (tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_GEMMA; + } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + // OrionStarAI/Orion-14B-Chat + return LLM_CHAT_TEMPLATE_ORION; + } else if (tmpl_contains("GPT4 Correct ")) { + // openchat/openchat-3.5-0106 + return LLM_CHAT_TEMPLATE_OPENCHAT; + } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + if (tmpl_contains("SYSTEM: ")) { + return LLM_CHAT_TEMPLATE_VICUNA_ORCA; + } + return LLM_CHAT_TEMPLATE_VICUNA; + } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) { + // deepseek-ai/deepseek-coder-33b-instruct + return LLM_CHAT_TEMPLATE_DEEPSEEK; + } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) { + // CohereForAI/c4ai-command-r-plus + return LLM_CHAT_TEMPLATE_COMMAND_R; + } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) { + return LLM_CHAT_TEMPLATE_LLAMA_3; + } else if (tmpl_contains("[gMASK]sop")) { + // chatglm3-6b + return LLM_CHAT_TEMPLATE_CHATGML_3; + } else if (tmpl_contains("[gMASK]")) { + return LLM_CHAT_TEMPLATE_CHATGML_4; + } else if (tmpl_contains(LU8("<用户>"))) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + return LLM_CHAT_TEMPLATE_MINICPM; + } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_2; + } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_3; + } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") 
&& tmpl_contains("[|endofturn|]")) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + return LLM_CHAT_TEMPLATE_EXAONE_3; + } else if (tmpl_contains("rwkv-world")) { + return LLM_CHAT_TEMPLATE_RWKV_WORLD; + } else if (tmpl_contains("<|start_of_role|>")) { + return LLM_CHAT_TEMPLATE_GRANITE; + } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { + return LLM_CHAT_TEMPLATE_GIGACHAT; + } else if (tmpl_contains("<|role_start|>")) { + return LLM_CHAT_TEMPLATE_MEGREZ; + } + return LLM_CHAT_TEMPLATE_UNKNOWN; +} + +// Simple version of "llama_apply_chat_template" that only works with strings +// This function uses heuristic checks to determine commonly used template. It is not a jinja parser. +int32_t llm_chat_apply_template( + llm_chat_template tmpl, + const std::vector & chat, + std::string & dest, bool add_ass) { + // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + std::stringstream ss; + if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; + } + if (add_ass) { + ss << "<|im_start|>assistant\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) { + // Official mistral 'v7' template + // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 + for (auto message : chat) { + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]"; + } else if (role == "user") { + ss << "[INST] " << content << "[/INST]"; + } + else { + ss << " " << content << "
"; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) { + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : ""; + std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " "; + bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3; + bool is_inside_turn = false; + for (auto message : chat) { + if (!is_inside_turn) { + ss << leading_space << "[INST]" << trailing_space; + is_inside_turn = true; + } + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << content << "\n\n"; + } else if (role == "user") { + ss << content << leading_space << "[/INST]"; + } else { + ss << trailing_space << (trim_assistant_message ? trim(content) : content) << ""; + is_inside_turn = false; + } + } + } else if ( + tmpl == LLM_CHAT_TEMPLATE_LLAMA_2 + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) { + // llama2 template and its variants + // [variant] support system message + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2; + // [variant] add BOS inside history + bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; + // [variant] trim spaces from the input message + bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; + // construct the prompt + bool is_inside_turn = true; // skip BOS at the beginning + ss << "[INST] "; + for (auto message : chat) { + std::string content = strip_message ? trim(message->content) : message->content; + std::string role(message->role); + if (!is_inside_turn) { + is_inside_turn = true; + ss << (add_bos_inside_history ? 
"[INST] " : "[INST] "); + } + if (role == "system") { + if (support_system_message) { + ss << "<>\n" << content << "\n<>\n\n"; + } else { + // if the model does not support system message, we still include it in the first message, but without <> + ss << content << "\n"; + } + } else if (role == "user") { + ss << content << " [/INST]"; + } else { + ss << content << ""; + is_inside_turn = false; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) { + // Phi 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "<|end|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_4) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "<|im_sep|>" << message->content << "<|im_end|>"; + } + if (add_ass) { + ss << "<|im_start|>assistant<|im_sep|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) { + // Falcon 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) { + // zephyr template + for (auto message : chat) { + ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) { + // mlabonne/AlphaMonarch-7B template (the is included inside history) + for (auto message : chat) { + std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message + ss << bos << message->role << "\n" << message->content << "\n"; + } + if (add_ass) { + ss << "assistant\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) { + // google/gemma-7b-it + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken + system_prompt = trim(message->content); + continue; + } + // in gemma, "assistant" is "model" + role = role == "assistant" ? 
"model" : message->role; + ss << "" << role << "\n"; + if (!system_prompt.empty() && role != "model") { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << trim(message->content) << "\n"; + } + if (add_ass) { + ss << "model\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) { + // OrionStarAI/Orion-14B-Chat + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message support, we will merge it with user prompt + system_prompt = message->content; + continue; + } else if (role == "user") { + ss << "Human: "; + if (!system_prompt.empty()) { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << message->content << "\n\nAssistant: "; + } else { + ss << message->content << ""; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) { + // openchat/openchat-3.5-0106, + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "<|end_of_turn|>"; + } else { + role[0] = toupper(role[0]); + ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>"; + } + } + if (add_ass) { + ss << "GPT4 Correct Assistant:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // Orca-Vicuna variant uses a system prefix + if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { + ss << "SYSTEM: " << message->content << "\n"; + } else { + ss << message->content << "\n\n"; + } + } else if (role == "user") { + ss << "USER: " << message->content << "\n"; + } else if (role == "assistant") { + ss << "ASSISTANT: " << message->content << "\n"; + } + } + if (add_ass) { + ss << "ASSISTANT:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) { + // deepseek-ai/deepseek-coder-33b-instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content; + } else if (role == "user") { + ss << "### Instruction:\n" << message->content << "\n"; + } else if (role == "assistant") { + ss << "### Response:\n" << message->content << "\n<|EOT|>\n"; + } + } + if (add_ass) { + ss << "### Response:\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) { + // CohereForAI/c4ai-command-r-plus + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "user") { + ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "assistant") { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } + } + if (add_ass) { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) { + // Llama 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>"; + } + if (add_ass) { + ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) { + // chatglm3-6b + ss << "[gMASK]" << "sop"; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n " << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } 
else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) { + ss << "[gMASK]" << ""; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << LU8("<用户>"); + ss << trim(message->content); + ss << ""; + } else { + ss << trim(message->content); + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) { + // DeepSeek-V2 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << "User: " << message->content << "\n\n"; + } else if (role == "assistant") { + ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << "Assistant:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) { + // DeepSeek-V3 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << LU8("<|User|>") << message->content; + } else if (role == "assistant") { + ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << LU8("<|Assistant|>"); + } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "user") { + ss << "[|user|]" << trim(message->content) << "\n"; + } else if (role == "assistant") { + ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n"; + } + } + if (add_ass) { + ss << "[|assistant|]"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { + // this template requires the model to have "\n\n" as EOT token + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << "User: " << message->content << "\n\nAssistant:"; + } else { + ss << message->content << "\n\n"; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { + // IBM Granite template + for (const auto & message : chat) { + std::string role(message->role); + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + if (role == "assistant_tool_call") { + ss << "<|tool_call|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { + // GigaChat template + bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; + + // Handle system message if present + if (has_system) { + ss << "" << chat[0]->content << "<|message_sep|>"; + } else { + ss << ""; + } + + // Process remaining messages + for (size_t i = has_system ? 
1 : 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (role == "user") { + ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>" + << "available functions<|role_sep|>[]<|message_sep|>"; + } else if (role == "assistant") { + ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>"; + } + } + + // Add generation prompt if needed + if (add_ass) { + ss << "assistant<|role_sep|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MEGREZ) { + // Megrez template + for (auto message : chat) { + std::string role(message->role); + ss << "<|role_start|>" << role << "<|role_end|>" << message->content << "<|turn_end|>"; + } + + if (add_ass) { + ss << "<|role_start|>assistant<|role_end|>"; + } + } else { + // template not supported + return -1; + } + dest = ss.str(); + return dest.size(); +} + +// public interface + +int32_t llama_chat_builtin_templates(const char ** output, size_t len) { + auto it = LLM_CHAT_TEMPLATES.begin(); + for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) { + output[i] = it->first.c_str(); + std::advance(it, 1); + } + return (int32_t) LLM_CHAT_TEMPLATES.size(); +} + diff --git a/src/llama-chat.h b/src/llama-chat.h new file mode 100644 index 000000000..3a4d07ce3 --- /dev/null +++ b/src/llama-chat.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include + +enum llm_chat_template { + LLM_CHAT_TEMPLATE_CHATML, + LLM_CHAT_TEMPLATE_LLAMA_2, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP, + LLM_CHAT_TEMPLATE_MISTRAL_V1, + LLM_CHAT_TEMPLATE_MISTRAL_V3, + LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, + LLM_CHAT_TEMPLATE_MISTRAL_V7, + LLM_CHAT_TEMPLATE_PHI_3, + LLM_CHAT_TEMPLATE_PHI_4, + LLM_CHAT_TEMPLATE_FALCON_3, + LLM_CHAT_TEMPLATE_ZEPHYR, + LLM_CHAT_TEMPLATE_MONARCH, + LLM_CHAT_TEMPLATE_GEMMA, + LLM_CHAT_TEMPLATE_ORION, + LLM_CHAT_TEMPLATE_OPENCHAT, + LLM_CHAT_TEMPLATE_VICUNA, + LLM_CHAT_TEMPLATE_VICUNA_ORCA, + LLM_CHAT_TEMPLATE_DEEPSEEK, + LLM_CHAT_TEMPLATE_DEEPSEEK_2, + LLM_CHAT_TEMPLATE_DEEPSEEK_3, + LLM_CHAT_TEMPLATE_COMMAND_R, + LLM_CHAT_TEMPLATE_LLAMA_3, + LLM_CHAT_TEMPLATE_CHATGML_3, + LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_MINICPM, + LLM_CHAT_TEMPLATE_EXAONE_3, + LLM_CHAT_TEMPLATE_RWKV_WORLD, + LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_GIGACHAT, + LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_UNKNOWN, +}; + +struct llama_chat_message; + +llm_chat_template llm_chat_template_from_str(const std::string & name); + +llm_chat_template llm_chat_detect_template(const std::string & tmpl); + +int32_t llm_chat_apply_template( + llm_chat_template tmpl, + const std::vector & chat, + std::string & dest, bool add_ass); diff --git a/src/llama-context.cpp b/src/llama-context.cpp new file mode 100644 index 000000000..38a55fb2c --- /dev/null +++ b/src/llama-context.cpp @@ -0,0 +1,1771 @@ +#include "llama-context.h" + +#include +#include +#include +#include + +void llama_set_k_shift(struct llama_context & lctx) { + const int64_t kv_size = lctx.kv_self.size; + + assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); + + int32_t * data = (int32_t *) lctx.inp_K_shift->data; + + for (int i = 0; i < kv_size; ++i) { + data[i] = lctx.kv_self.cells[i].delta; + } +} + +void llama_set_s_copy(struct llama_context & lctx) { + const int64_t kv_size = lctx.kv_self.size; + + assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + + int32_t * data = (int32_t *) lctx.inp_s_copy->data; + + for (int i = 0; i < kv_size; ++i) { + data[i] = lctx.kv_self.cells[i].src; + } +} 
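+// note: the two helpers above fill the K-shift and recurrent-state copy inputs by
+// writing one value per KV cell directly into host-resident tensor buffers
+// (hence the ggml_backend_buffer_is_host asserts)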
+ +// llama input + +static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + n_buckets >>= 1; + } + + const int64_t max_exact = n_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + if (bidirectional) { + relative_bucket += (relative_position > 0) * n_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); + relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); + return relative_bucket; +} + +void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { + // + // set input data + // + + const auto & hparams = lctx.model.hparams; + const auto & cparams = lctx.cparams; + const auto & kv_self = lctx.kv_self; + + if (ubatch.token) { + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens)); + } + + if (ubatch.embd) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); + } + + if (ubatch.pos && lctx.inp_pos) { + const int64_t n_tokens = ubatch.n_tokens; + auto n_pos = lctx.n_pos_per_token; + ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos)); + } + + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + //GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); + + if (!lctx.inp_out_ids) { + LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__); + } else { + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer)); + int32_t * data = (int32_t *) lctx.inp_out_ids->data; + + if (lctx.n_outputs == n_tokens) { + for (int i = 0; i < n_tokens; ++i) { + data[i] = i; + } + } else if (ubatch.output) { + int32_t n_outputs = 0; + for (int i = 0; i < n_tokens; ++i) { + if (ubatch.output[i]) { + data[n_outputs++] = i; + } + } + // the graph needs to have been passed the correct number of outputs + GGML_ASSERT(lctx.n_outputs == n_outputs); + } else if (lctx.n_outputs == 1) { + // only keep last output + data[0] = n_tokens - 1; + } else { + GGML_ASSERT(lctx.n_outputs == 0); + } + } + } + + GGML_ASSERT( + // (!a || b) is a logical implication (a -> b) + // !hparams.causal_attn -> !cparams.causal_attn + (hparams.causal_attn || !cparams.causal_attn) && + "causal attention is not supported by this model" + ); + + if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) { + // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. 
+ if (cparams.causal_attn && !lctx.is_encoding) { + const int64_t n_kv = kv_self.n; + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + + float * data = nullptr; + float * data_swa = nullptr; + + if (lctx.inp_KQ_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); + data = (float *) lctx.inp_KQ_mask->data; + } + + if (lctx.inp_KQ_mask_swa) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); + data_swa = (float *) lctx.inp_KQ_mask_swa->data; + } + + // For causal attention, use only the previous KV cells + // of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + for (int h = 0; h < 1; ++h) { + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = ubatch.pos[s*n_seq_tokens + j]; + + for (int i = 0; i < n_kv; ++i) { + float f; + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + f = -INFINITY; + } else { + if (hparams.use_alibi) { + f = -std::abs(kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } + } + + if (data) { + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + + // may need to cut off old tokens for sliding window + if (data_swa) { + if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + } + } + } + + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + + if (data_swa) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + } + } else { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + // when using kv cache, the mask needs to match the kv cache size + const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? 
kv_self.n : n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); + + float * data = (float *) lctx.inp_KQ_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch.seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { + if (ubatch.seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; + } + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } + } + } + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(lctx.inp_mean); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); + + float * data = (float *) lctx.inp_mean->data; + memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); + + std::vector sum(n_tokens, 0); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); + + sum[seq_id] += ubatch.n_seq_tokens; + } + + std::vector div(n_tokens, 0.0f); + for (int i = 0; i < n_tokens; ++i) { + const uint64_t s = sum[i]; + if (s > 0) { + div[i] = 1.0f/float(s); + } + } + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } + } + } + + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(lctx.inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); + + uint32_t * data = (uint32_t *) lctx.inp_cls->data; + memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(lctx.inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); + + uint32_t * data = (uint32_t *) lctx.inp_cls->data; + memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); + + std::vector last_pos(n_tokens, -1); + std::vector last_row(n_tokens, -1); + + for 
(int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = s*n_seq_tokens + i; + } + } + } + + for (int i = 0; i < n_tokens; ++i) { + if (last_row[i] >= 0) { + data[i] = last_row[i]; + } + } + } + + if (kv_self.recurrent) { + const int64_t n_kv = kv_self.n; + + if (lctx.inp_s_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); + float * data = (float *) lctx.inp_s_mask->data; + + // clear unused states + for (int i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; + + data[i] = (float) (kv_cell.src >= 0); + + // only clear once + if (kv_cell.src < 0) { + kv_cell.src = cell_id; + } + } + } + + if (lctx.inp_s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; + + // prevent out-of-bound sources + if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + kv_cell.src = cell_id; + } + + data[i] = kv_cell.src; + + // ensure copy only happens once + if (kv_cell.src != (int32_t) cell_id) { + kv_cell.src = cell_id; + } + } + } + } + + if (lctx.inp_pos_bucket) { + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing + + int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; + + if (!lctx.is_encoding) { + const int64_t n_kv = kv_self.n; + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + } + } + } + } else { + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + } + } + } + } + } + + if (!lctx.is_encoding && lctx.inp_embd_enc) { + assert(lctx.inp_embd_enc->type == GGML_TYPE_F32); + assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size()); + + ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc)); + } + + if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) { + const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd; + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer)); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing + + float * data = (float *) lctx.inp_KQ_mask_cross->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_output_enc; ++i) { + float f = -INFINITY; + for (int s = 0; s < ubatch.n_seq_id[j]; ++s) { + const llama_seq_id seq_id = 
ubatch.seq_id[j][s]; + if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) { + f = 0.0f; + } + } + data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f; + } + } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_output_enc; ++j) { + data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY; + } + } + } + } +} + +// llama output + +size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { + const auto & cparams = lctx.cparams; + const auto & hparams = lctx.model.hparams; + + const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max); + + const auto n_batch = cparams.n_batch; + const auto n_vocab = hparams.n_vocab; + const auto n_embd = hparams.n_embd; + + // TODO: use a per-batch flag for logits presence instead + const bool has_logits = !cparams.embeddings; + const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + + const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; + const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; + + if (lctx.output_ids.empty()) { + // init, never resized afterwards + lctx.output_ids.resize(n_batch); + } + + const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0; + const size_t new_size = (logits_size + embd_size) * sizeof(float); + + // alloc only when more than the current capacity is required + // TODO: also consider shrinking the buffer + if (!lctx.buf_output || prev_size < new_size) { + if (lctx.buf_output) { +#ifndef NDEBUG + // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) + LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); +#endif + lctx.buf_output = nullptr; + lctx.logits = nullptr; + lctx.embd = nullptr; + } + + auto * buft = ggml_backend_cpu_buffer_type(); + // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory + auto * output_dev = lctx.model.dev_output.dev; + auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; + if (output_dev_host_buft) { + buft = output_dev_host_buft; + } + lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); + if (lctx.buf_output == nullptr) { + LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); + return 0; + } + } + + float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get()); + + lctx.logits = has_logits ? output_base : nullptr; + lctx.embd = has_embd ? output_base + logits_size : nullptr; + + lctx.output_size = n_outputs_max; + lctx.logits_size = logits_size; + lctx.embd_size = embd_size; + + // set all ids as invalid (negative) + std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1); + + ggml_backend_buffer_clear(lctx.buf_output.get(), 0); + + lctx.n_outputs = 0; + + return n_outputs_max; +} + +void llama_output_reorder(struct llama_context & ctx) { + std::vector & out_ids = ctx.sbatch.out_ids; + if (!out_ids.empty()) { + const uint32_t n_vocab = ctx.model.hparams.n_vocab; + const uint32_t n_embd = ctx.model.hparams.n_embd; + + const int32_t n_outputs = ctx.n_outputs; + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + + // TODO: is there something more efficient which also minimizes swaps? 
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (ctx.logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]); + } + } + if (ctx.embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]); + } + } + } + std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + ctx.output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} + +// +// interface implementation +// + +void llama_free(struct llama_context * ctx) { + delete ctx; +} + +uint32_t llama_n_ctx(const struct llama_context * ctx) { + return ctx->cparams.n_ctx; +} + +uint32_t llama_n_batch(const struct llama_context * ctx) { + return ctx->cparams.n_batch; +} + +uint32_t llama_n_ubatch(const struct llama_context * ctx) { + return ctx->cparams.n_ubatch; +} + +uint32_t llama_n_seq_max(const struct llama_context * ctx) { + return ctx->kv_self.size; +} + +const struct llama_model * llama_get_model(const struct llama_context * ctx) { + return &ctx->model; +} + +enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { + return ctx->cparams.pooling_type; +} + +void llama_attach_threadpool( + struct llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) { + ctx->threadpool = threadpool; + ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; +} + +void llama_detach_threadpool(struct llama_context * ctx) { + ctx->threadpool = nullptr; + ctx->threadpool_batch = nullptr; +} + +void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { + ctx->cparams.n_threads = n_threads; + ctx->cparams.n_threads_batch = n_threads_batch; +} + +int32_t llama_n_threads(struct llama_context * ctx) { + return ctx->cparams.n_threads; +} + +int32_t llama_n_threads_batch(struct llama_context * ctx) { + return ctx->cparams.n_threads_batch; +} + +void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = abort_callback_data; + + for (auto & backend : ctx->backends) { + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data); + } + } +} + +void llama_set_embeddings(struct llama_context * ctx, bool embeddings) { + ctx->cparams.embeddings = embeddings; +} + +void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { + ctx->cparams.causal_attn = causal_attn; +} + +void llama_synchronize(struct llama_context * ctx) { + ggml_backend_sched_synchronize(ctx->sched.get()); + + // FIXME: if multiple single tokens are evaluated without a synchronization, + // the stats will be added to the prompt evaluation stats + // this should only happen when using batch size 1 to evaluate a batch + + // add the evaluation to the stats + if (ctx->n_queued_tokens == 1) { + if 
(!ctx->cparams.no_perf) { + ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us; + } + ctx->n_eval++; + } else if (ctx->n_queued_tokens > 1) { + if (!ctx->cparams.no_perf) { + ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us; + } + ctx->n_p_eval += ctx->n_queued_tokens; + } + + // get a more accurate load time, upon first eval + if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) { + ctx->t_load_us = ggml_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } + + ctx->n_queued_tokens = 0; + ctx->t_compute_start_us = 0; +} + +float * llama_get_logits(struct llama_context * ctx) { + llama_synchronize(ctx); + + // reorder logits for backward compatibility + // TODO: maybe deprecate this + llama_output_reorder(*ctx); + + return ctx->logits; +} + +float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + int32_t j = -1; + + llama_synchronize(ctx); + + try { + if (ctx->logits == nullptr) { + throw std::runtime_error("no logits"); + } + + if (i < 0) { + j = ctx->n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; + } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= ctx->n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); + } + + return ctx->logits + j*ctx->model.hparams.n_vocab; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_get_embeddings(struct llama_context * ctx) { + llama_synchronize(ctx); + + // reorder embeddings for backward compatibility + // TODO: maybe deprecate this + llama_output_reorder(*ctx); + + return ctx->embd; +} + +float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { + int32_t j = -1; + + llama_synchronize(ctx); + + try { + if (ctx->embd == nullptr) { + throw std::runtime_error("no embeddings"); + } + + if (i < 0) { + j = ctx->n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; + } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= ctx->n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); + } + + return ctx->embd + j*ctx->model.hparams.n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) { + llama_synchronize(ctx); + + auto it = ctx->embd_seq.find(seq_id); + if (it == ctx->embd_seq.end()) { + return nullptr; + } + + return it->second.data(); +} + +// llama state API + +// deprecated +size_t llama_get_state_size(struct llama_context * ctx) { + return llama_state_get_size(ctx); +} + +// deprecated +size_t 
llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + return llama_state_get_data(ctx, dst, -1); +} + +// deprecated +size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { + return llama_state_set_data(ctx, src, -1); +} + +// deprecated +bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); +} + +// deprecated +bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + return llama_state_save_file(ctx, path_session, tokens, n_token_count); +} + +// TODO: replace all non-fatal assertions with returned errors or exceptions +struct llama_data_write { + virtual void write(const void * src, size_t size) = 0; + virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0; + virtual size_t get_size_written() = 0; + virtual ~llama_data_write() = default; + + void write_string(const std::string & str) { + uint32_t str_size = str.size(); + + write(&str_size, sizeof(str_size)); + write(str.data(), str_size); + } + + void write_model_info(const struct llama_context * ctx) { + const std::string arch_str = llm_arch_name(ctx->model.arch); + write_string(arch_str); + // TODO: add more model-specific info which should prevent loading the session file if not identical + } + + //void write_rng(const std::mt19937 & rng) { + // std::ostringstream rng_ss; + // rng_ss << rng; + + // const std::string & rng_str = rng_ss.str(); + + // write_string(rng_str); + //} + + void write_output_ids(struct llama_context * ctx) { + llama_output_reorder(*ctx); + + const uint32_t n_outputs = ctx->n_outputs; + + std::vector output_pos; + + const size_t n_batch = ctx->cparams.n_batch; + const auto & output_ids = ctx->output_ids; + + GGML_ASSERT(n_outputs <= ctx->output_size); + + output_pos.resize(n_outputs); + + // build a more compact representation of the output ids + for (size_t i = 0; i < n_batch; ++i) { + // map an output id to a position in the batch + int32_t pos = output_ids[i]; + if (pos >= 0) { + GGML_ASSERT((uint32_t) pos < n_outputs); + output_pos[pos] = i; + } + } + + write(&n_outputs, sizeof(n_outputs)); + + if (n_outputs) { + write(output_pos.data(), n_outputs * sizeof(int32_t)); + } + } + + void write_logits(const struct llama_context * ctx) { + const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab); + + write(&logits_size, sizeof(logits_size)); + + if (logits_size) { + write(ctx->logits, logits_size * sizeof(float)); + } + } + + void write_embeddings(const struct llama_context * ctx) { + const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd); + + write(&embeddings_size, sizeof(embeddings_size)); + + if (embeddings_size) { + write(ctx->embd, embeddings_size * sizeof(float)); + } + } + + void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = kv_self.cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? 
cell.seq_id.size() : 0; + + write(&pos, sizeof(pos)); + write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + write(&seq_id, sizeof(seq_id)); + } + } + } + } + } + + void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { + const struct llama_kv_cache & kv_self = ctx->kv_self; + const struct llama_hparams & hparams = ctx->model.hparams; + + const uint32_t v_trans = kv_self.v_trans ? 1 : 0; + const uint32_t n_layer = hparams.n_layer; + + write(&v_trans, sizeof(v_trans)); + write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!kv_self.v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = kv_self.size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + write_tensor_data(kv_self.v_l[il], src_offset, buf_size); + } + } + } + } + } + + void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { + const struct llama_kv_cache & kv_self = ctx->kv_self; + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified 
seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = kv_self.size; + for (uint32_t i = 0; i < kv_self.size; ++i) { + const auto & cell = kv_self.cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == kv_self.size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != kv_self.size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = kv_self.size; + } + } + } + if (cell_range_begin != kv_self.size) { + cell_ranges.emplace_back(cell_range_begin, kv_self.size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + write(&cell_count, sizeof(cell_count)); + + write_kv_cache_meta(kv_self, cell_ranges, seq_id); + write_kv_cache_data(ctx, cell_ranges); + } +}; + +struct llama_data_read { + virtual const uint8_t * read(size_t size) = 0; + virtual void read_to(void * dst, size_t size) = 0; + virtual size_t get_size_read() = 0; + virtual ~llama_data_read() = default; + + void read_string(std::string & str) { + uint32_t str_size; + read_to(&str_size, sizeof(str_size)); + + str.assign((const char *) read(str_size), str_size); + } + + // validate model information + void read_model_info(const struct llama_context * ctx) { + const std::string cur_arch_str = llm_arch_name(ctx->model.arch); + + std::string arch_str; + read_string(arch_str); + if (cur_arch_str != arch_str) { + throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); + } + // TODO: add more info which needs to be identical but which is not verified otherwise + } + + //void read_rng(std::mt19937 & rng) { + // std::string rng_str; + // read_string(rng_str); + + // std::istringstream rng_ss(rng_str); + // rng_ss >> rng; + + // if (rng_ss.fail()) { + // throw std::runtime_error("failed to load RNG state"); + // } + //} + + void read_output_ids(struct llama_context * ctx) { + std::vector output_pos; + + uint32_t n_outputs; + read_to(&n_outputs, sizeof(n_outputs)); + + if (n_outputs > llama_output_reserve(*ctx, n_outputs)) { + throw std::runtime_error("could not reserve outputs"); + } + + if (n_outputs) { + output_pos.resize(n_outputs); + read_to(output_pos.data(), n_outputs * sizeof(int32_t)); + + for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { + int32_t id = output_pos[i]; + if ((uint32_t) id >= ctx->cparams.n_batch) { + throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch)); + } + ctx->output_ids[id] = i; + } + + ctx->n_outputs = n_outputs; + } + } + + void read_logits(struct llama_context * ctx) { + uint64_t logits_size; + read_to(&logits_size, sizeof(logits_size)); + + if (ctx->logits_size < logits_size) { + throw std::runtime_error("logits buffer too small"); + } + + if (logits_size) { + read_to(ctx->logits, logits_size * sizeof(float)); + } + } + + void read_embeddings(struct llama_context * ctx) { + uint64_t embeddings_size; + read_to(&embeddings_size, sizeof(embeddings_size)); + + if (ctx->embd_size < embeddings_size) { + throw std::runtime_error("embeddings buffer too small"); + } + + if (embeddings_size) { + read_to(ctx->embd, embeddings_size * sizeof(float)); + } + } + + bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, 
llama_seq_id dest_seq_id = -1) { + struct llama_kv_cache & kv_self = ctx->kv_self; + + if (dest_seq_id != -1) { + // single sequence + + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + + llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + if (!llama_kv_cache_find_slot(kv_self, batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); + GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > kv_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + llama_kv_cache_clear(kv_self); + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_kv_cell & cell = kv_self.cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + read_to(&seq_id, sizeof(seq_id)); + + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + return false; + } + + cell.seq_id.insert(seq_id); + + if (kv_self.recurrent) { + int32_t & tail = kv_self.cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + } + + kv_self.head = 0; + kv_self.used = cell_count; + } + + if (kv_self.recurrent) { + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = kv_self.head + i; + // make sure the recurrent states will keep their restored state + kv_self.cells[cell_id].src = cell_id; + } + } + + return true; + } + + bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { + const struct llama_hparams & hparams = ctx->model.hparams; + struct llama_kv_cache & kv_self = ctx->kv_self; + uint32_t v_trans; + uint32_t n_layer; + read_to(&v_trans, sizeof(v_trans)); + read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > kv_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size); + return false; + } + if (kv_self.v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the 
keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row); + } + } + + if (!kv_self.v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_self.head + j * 
kv_self.size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + return true; + } + + void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) { + uint32_t cell_count; + read_to(&cell_count, sizeof(cell_count)); + + bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count); + + if (!res) { + if (seq_id == -1) { + llama_kv_cache_clear(ctx); + } else { + llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } + } +}; + +struct llama_data_write_dummy : llama_data_write { + size_t size_written = 0; + + llama_data_write_dummy() {} + + void write(const void * /* src */, size_t size) override { + size_written += size; + } + + void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_write_buffer : llama_data_write { + uint8_t * ptr; + size_t buf_size = 0; + size_t size_written = 0; + + llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + + void write(const void * src, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(ptr, src, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + ggml_backend_tensor_get(tensor, ptr, offset, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_read_buffer : llama_data_read { + const uint8_t * ptr; + size_t buf_size = 0; + size_t size_read = 0; + + llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + + const uint8_t * read(size_t size) override { + const uint8_t * base_ptr = ptr; + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + ptr += size; + size_read += size; + buf_size -= size; + return base_ptr; + } + + void read_to(void * dst, size_t size) override { + memcpy(dst, read(size), size); + } + + size_t get_size_read() override { + return size_read; + } +}; + +struct llama_data_write_file : llama_data_write { + llama_file * file; + size_t size_written = 0; + std::vector temp_buffer; + + llama_data_write_file(llama_file * f) : file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); + size_written += size; + } + + void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + temp_buffer.resize(size); + ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); + write(temp_buffer.data(), temp_buffer.size()); + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_read_file : llama_data_read { + llama_file * file; + size_t size_read = 0; + std::vector temp_buffer; + + llama_data_read_file(llama_file * f) : file(f) {} + + void read_to(void * dst, size_t size) override { + file->read_raw(dst, size); + size_read += size; + } + + const uint8_t * read(size_t size) override { + temp_buffer.resize(size); + read_to(temp_buffer.data(), size); + return temp_buffer.data(); + } + + size_t 
get_size_read() override { + return size_read; + } +}; + +/** copy state data into either a buffer or file depending on the passed in context + * + * file context: + * llama_file file("/path", "wb"); + * llama_data_write_file data_ctx(&file); + * llama_state_get_data_internal(ctx, data_ctx); + * + * buffer context: + * std::vector buf(max_size, 0); + * llama_data_write_buffer data_ctx(buf.data(), max_size); + * llama_state_get_data_internal(ctx, data_ctx); + * +*/ +static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) { + llama_synchronize(ctx); + + data_ctx.write_model_info(ctx); + + // copy outputs + data_ctx.write_output_ids(ctx); + data_ctx.write_logits(ctx); + data_ctx.write_embeddings(ctx); + + data_ctx.write_kv_cache(ctx); + + return data_ctx.get_size_written(); +} + +size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) { + llama_data_write_buffer data_ctx(dst, size); + try { + return llama_state_get_data_internal(ctx, data_ctx); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); + return 0; + } +} + +// Returns the *actual* size of the state. +// Intended to be used when saving to state to a buffer. +size_t llama_state_get_size(struct llama_context * ctx) { + llama_data_write_dummy data_ctx; + try { + return llama_state_get_data_internal(ctx, data_ctx); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); + return 0; + } +} + +static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) { + llama_synchronize(ctx); + + data_ctx.read_model_info(ctx); + + // set outputs + data_ctx.read_output_ids(ctx); + data_ctx.read_logits(ctx); + data_ctx.read_embeddings(ctx); + + data_ctx.read_kv_cache(ctx); + + return data_ctx.get_size_read(); +} + +// Sets the state reading from the specified source address +size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) { + llama_data_read_buffer data_ctx(src, size); + try { + return llama_state_set_data_internal(ctx, data_ctx); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); + return 0; + } +} + +static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(path_session, "rb"); + + // sanity checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { + LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); + return false; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return false; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t n_state_size_cur = file.size() - file.tell(); + + llama_data_read_file data_ctx(&file); + const size_t n_read = llama_state_set_data_internal(ctx, data_ctx); + + if (n_read != n_state_size_cur) { + LLAMA_LOG_ERROR("%s: did not read all of the session file data! 
size %zu, got %zu\n", __func__, n_state_size_cur, n_read); + return false; + } + } + return true; +} + +bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); + return false; + } +} + +static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + llama_file file(path_session, "wb"); + + file.write_u32(LLAMA_SESSION_MAGIC); + file.write_u32(LLAMA_SESSION_VERSION); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_write_file data_ctx(&file); + llama_state_get_data_internal(ctx, data_ctx); + + return true; +} + +bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); + return false; + } +} + +static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) { + llama_synchronize(ctx); + + data_ctx.write_kv_cache(ctx, seq_id); + + return data_ctx.get_size_written(); +} + +size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) { + llama_data_write_dummy data_ctx; + return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); +} + +size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { + llama_data_write_buffer data_ctx(dst, size); + try { + return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what()); + return 0; + } +} + +static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) { + llama_synchronize(ctx); + + data_ctx.read_kv_cache(ctx, dest_seq_id); + + return data_ctx.get_size_read(); +} + +size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) { + llama_data_read_buffer data_ctx(src, size); + try { + return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what()); + return 0; + } +} + +static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); + + file.write_u32(LLAMA_STATE_SEQ_MAGIC); + file.write_u32(LLAMA_STATE_SEQ_VERSION); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_write_file data_ctx(&file); + llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + + const size_t res = file.tell(); + GGML_ASSERT(res == 
sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); + return res; +} + +static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(filepath, "rb"); + + // version checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { + LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); + return 0; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return 0; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t state_size = file.size() - file.tell(); + llama_data_read_file data_ctx(&file); + const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); + if (!nread) { + LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); + return 0; + } + GGML_ASSERT(nread <= state_size); + GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); + } + + return file.tell(); +} + +size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); + return 0; + } +} + +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +) { + return ctx->model.tensors_by_name; +} diff --git a/src/llama-context.h b/src/llama-context.h new file mode 100644 index 000000000..0d163c470 --- /dev/null +++ b/src/llama-context.h @@ -0,0 +1,128 @@ +#pragma once + +#include "llama.h" +#include "llama-batch.h" +#include "llama-cparams.h" +#include "llama-model.h" +#include "llama-kv-cache.h" +#include "llama-adapter.h" + +#include "ggml-cpp.h" + +#include +#include +#include +#include + +struct llama_context { + llama_context(const llama_model & model) + : model(model) + , t_start_us(model.t_start_us) + , t_load_us(model.t_load_us) {} + + const struct llama_model & model; + + struct llama_cparams cparams; + struct llama_sbatch sbatch; // TODO: revisit if needed + struct llama_kv_cache kv_self; + struct llama_control_vector cvec; + + std::unordered_map lora_adapters; + + std::vector backends; + std::vector> set_n_threads_fns; + + ggml_backend_t backend_cpu = nullptr; + + ggml_threadpool_t threadpool = nullptr; + ggml_threadpool_t threadpool_batch = nullptr; + + bool 
has_evaluated_once = false; + + mutable int64_t t_start_us; + mutable int64_t t_load_us; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; + + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; + + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls + + // host buffer for the model output (logits and embeddings) + ggml_backend_buffer_ptr buf_output; + + // decode output (2-dimensional array: [n_outputs][n_vocab]) + size_t logits_size = 0; // capacity (of floats) for logits + float * logits = nullptr; + + std::vector output_ids; // map batch token positions to ids of the logits and embd buffers + size_t output_size = 0; // capacity (of tokens positions) for the output buffers + int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch + + bool logits_all = false; + + // embeddings output (2-dimensional array: [n_outputs][n_embd]) + // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE + size_t embd_size = 0; // capacity (of floats) for embeddings + float * embd = nullptr; + + // sequence embeddings output (map of [n_embd] vectors) + // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE + std::map> embd_seq; + + // whether we are computing encoder output or decoder output + bool is_encoding = false; + + // TODO: find a better way to accommodate mutli-dimension position encoding methods + // number of position id each token get, 1 for each token in most cases. + // when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate. + int n_pos_per_token = 1; + + // output of the encoder part of the encoder-decoder models + std::vector embd_enc; + std::vector> seq_ids_enc; + + // memory buffers used to evaluate the model + std::vector buf_compute_meta; + ggml_backend_sched_ptr sched; + + ggml_abort_callback abort_callback = nullptr; + void * abort_callback_data = nullptr; + + // input tensors + struct ggml_tensor * inp_tokens; // I32 [n_batch] + struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] + struct ggml_tensor * inp_pos; // I32 [n_batch] + struct ggml_tensor * inp_out_ids; // I32 [n_outputs] + struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_K_shift; // I32 [kv_size] + struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] + struct ggml_tensor * inp_cls; // I32 [n_batch] + struct ggml_tensor * inp_s_copy; // I32 [kv_size] + struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] + struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] + struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] + struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] +}; + +// TODO: make these methods of llama_context +void llama_set_k_shift(struct llama_context & lctx); + +void llama_set_s_copy(struct llama_context & lctx); + +void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch); + +// Make sure enough space is available for outputs. +// Returns max number of outputs for which space was reserved. 
+size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs); + +// make the outputs have the same order they had in the user-provided batch +void llama_output_reorder(struct llama_context & ctx); + +// For internal test use +// TODO: remove +const std::vector> & llama_internal_get_tensor_map(struct llama_context * ctx); diff --git a/src/llama-cparams.cpp b/src/llama-cparams.cpp new file mode 100644 index 000000000..28369be36 --- /dev/null +++ b/src/llama-cparams.cpp @@ -0,0 +1 @@ +#include "llama-cparams.h" diff --git a/src/llama-cparams.h b/src/llama-cparams.h new file mode 100644 index 000000000..252012f3d --- /dev/null +++ b/src/llama-cparams.h @@ -0,0 +1,37 @@ +#pragma once + +#include "llama.h" + +#include + +struct llama_cparams { + uint32_t n_ctx; // context size used during inference + uint32_t n_batch; + uint32_t n_ubatch; + uint32_t n_seq_max; + int n_threads; // number of threads to use for generation + int n_threads_batch; // number of threads to use for batch processing + + float rope_freq_base; + float rope_freq_scale; + + uint32_t n_ctx_orig_yarn; + // These hyperparameters are not exposed in GGUF, because all + // existing YaRN models use the same values for them. + float yarn_ext_factor; + float yarn_attn_factor; + float yarn_beta_fast; + float yarn_beta_slow; + float defrag_thold; + + bool embeddings; + bool causal_attn; + bool offload_kqv; + bool flash_attn; + bool no_perf; + + enum llama_pooling_type pooling_type; + + ggml_backend_sched_eval_callback cb_eval; + void * cb_eval_user_data; +}; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 76d0cb3a2..186dc9a25 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1,5 +1,6 @@ #include "llama-grammar.h" +#include "llama-impl.h" #include "llama-vocab.h" #include "llama-sampling.h" diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 13e940fb5..f8b40c651 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -1,8 +1,10 @@ #pragma once -#include "llama-impl.h" +#include "llama.h" #include +#include +#include struct llama_vocab; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp new file mode 100644 index 000000000..c40534696 --- /dev/null +++ b/src/llama-hparams.cpp @@ -0,0 +1,71 @@ +#include "llama-hparams.h" + +#include "ggml.h" + +uint32_t llama_hparams::n_head(uint32_t il) const { + if (il < n_layer) { + return n_head_arr[il]; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_head_kv(uint32_t il) const { + if (il < n_layer) { + return n_head_kv_arr[il]; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_ff(uint32_t il) const { + if (il < n_layer) { + return n_ff_arr[il]; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_gqa(uint32_t il) const { + const uint32_t n_head = this->n_head(il); + const uint32_t n_head_kv = this->n_head_kv(il); + + if (n_head_kv == 0) { + return 0; + } + + return n_head/n_head_kv; +} + +uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { + const uint32_t n_head_kv = this->n_head_kv(il); + + return n_embd_head_k * n_head_kv; +} + +uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { + const uint32_t n_head_kv = this->n_head_kv(il); + + return n_embd_head_v * n_head_kv; +} + +uint32_t llama_hparams::n_embd_k_s() const { + if (wkv_head_size != 0) { + // for RWKV models + return 2 * n_embd; + } + + // TODO: maybe support other convolution strides than 1 + // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed + return 
(ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; +} + +uint32_t llama_hparams::n_embd_v_s() const { + if (wkv_head_size != 0) { + // corresponds to RWKV's wkv_states size + return n_embd * wkv_head_size; + } + + // corresponds to Mamba's ssm_states size + return ssm_d_state * ssm_d_inner; +} diff --git a/src/llama-hparams.h b/src/llama-hparams.h new file mode 100644 index 000000000..a29f20ec4 --- /dev/null +++ b/src/llama-hparams.h @@ -0,0 +1,140 @@ +#pragma once + +#include "llama.h" + +#include <array> + +// bump if necessary +#define LLAMA_MAX_LAYERS 512 +#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3 + +enum llama_expert_gating_func_type { + LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, +}; + +struct llama_hparams_posnet { + uint32_t n_embd; + uint32_t n_layer; +}; + +struct llama_hparams_convnext { + uint32_t n_embd; + uint32_t n_layer; +}; + +struct llama_hparams { + bool vocab_only; + bool rope_finetuned; + bool use_par_res; + bool swin_norm; + + uint32_t n_vocab = 0; + uint32_t n_ctx_train; // context size the model was trained on + uint32_t n_embd; + uint32_t n_embd_features = 0; + uint32_t n_layer; + uint32_t n_rot; + uint32_t n_swa = 0; // sliding window attention (SWA) + uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads + uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head + uint32_t n_expert = 0; + uint32_t n_expert_used = 0; + uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_rel_attn_bkts = 0; + + // for WavTokenizer + struct llama_hparams_posnet posnet; + struct llama_hparams_convnext convnext; + + std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr; + std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr; + std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr; + + uint32_t n_layer_dense_lead = 0; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_ff_exp = 0; + uint32_t n_ff_shexp = 0; + uint32_t n_expert_shared = 0; + uint32_t n_norm_groups = 0; + + float expert_weights_scale = 0.0; + bool expert_weights_norm = false; + uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; + + float f_norm_eps; + float f_norm_rms_eps; + float f_norm_group_eps; + + float f_attn_logit_softcapping = 50.0f; + float f_final_logit_softcapping = 30.0f; + + // for RWKV + uint32_t rescale_every_n_layers = 0; + uint32_t time_mix_extra_dim = 0; + uint32_t time_decay_extra_dim = 0; + uint32_t wkv_head_size = 0; + + float rope_attn_factor = 1.0f; + float rope_freq_base_train; + float rope_freq_scale_train; + uint32_t n_ctx_orig_yarn; + float rope_yarn_log_mul; + + std::array<int, 4> rope_sections; + + // for State Space Models + uint32_t ssm_d_conv = 0; + uint32_t ssm_d_inner = 0; + uint32_t ssm_d_state = 0; + uint32_t ssm_dt_rank = 0; + + bool ssm_dt_b_c_rms = false; + + float f_clamp_kqv = 0.0f; + float f_max_alibi_bias = 0.0f; + float f_logit_scale = 0.0f; + + // Additional scale factors (Granite/Granite MoE) + float f_residual_scale = 0.0f; + float f_embedding_scale = 0.0f; + float f_attention_scale = 0.0f; + + bool causal_attn = true; + bool use_alibi = false; + bool attn_soft_cap = false; + + // needed by encoder-decoder models (e.g. 
T5, FLAN-T5) + // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + llama_token dec_start_token_id = LLAMA_TOKEN_NULL; + + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; + enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; + enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + uint32_t n_head(uint32_t il = 0) const; + + uint32_t n_head_kv(uint32_t il = 0) const; + + uint32_t n_ff(uint32_t il = 0) const; + + uint32_t n_gqa(uint32_t il = 0) const; + + // dimension of key embeddings across all k-v heads + uint32_t n_embd_k_gqa(uint32_t il = 0) const; + + // dimension of value embeddings across all k-v heads + uint32_t n_embd_v_gqa(uint32_t il = 0) const; + + // dimension of the rolling state embeddings + // corresponds to Mamba's conv_states size or RWKV's token_shift states size + uint32_t n_embd_k_s() const; + + // dimension of the recurrent state embeddings + uint32_t n_embd_v_s() const; +}; + +static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); + diff --git a/src/llama-impl.cpp b/src/llama-impl.cpp new file mode 100644 index 000000000..6ec709dd3 --- /dev/null +++ b/src/llama-impl.cpp @@ -0,0 +1,167 @@ +#include "llama-impl.h" + +#include "gguf.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include + +struct llama_logger_state { + ggml_log_callback log_callback = llama_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static llama_logger_state g_logger_state; + +time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} + +time_meas::~time_meas() { + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; + } + } + +void llama_log_set(ggml_log_callback log_callback, void * user_data) { + ggml_log_set(log_callback, user_data); + g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default; + g_logger_state.log_callback_user_data = user_data; +} + +static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) { + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data); + } else { + char * buffer2 = new char[len + 1]; + vsnprintf(buffer2, len + 1, format, args_copy); + buffer2[len] = 0; + g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args_copy); +} + +void llama_log_internal(ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + llama_log_internal_v(level, format, args); + va_end(args); +} + +void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +void replace_all(std::string & s, const std::string & search, const std::string & replace) { + if (search.empty()) { + return; + } + std::string builder; + builder.reserve(s.length()); + size_t pos = 0; + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +std::string format(const char * fmt, ...) 
{ + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +std::string llama_format_tensor_shape(const std::vector & ne) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + for (size_t i = 1; i < ne.size(); i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + } + return buf; +} + +std::string llama_format_tensor_shape(const struct ggml_tensor * t) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + } + return buf; +} + +static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) { + switch (type) { + case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]); + case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]); + case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]); + case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]); + case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]); + case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]); + case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]); + case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); + case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); + case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); + case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + default: return format("unknown type %d", type); + } +} + +std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + + switch (type) { + case GGUF_TYPE_STRING: + return gguf_get_val_str(ctx_gguf, i); + case GGUF_TYPE_ARRAY: + { + const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); + int arr_n = gguf_get_arr_n(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); + std::stringstream ss; + ss << "["; + for (int j = 0; j < arr_n; j++) { + if (arr_type == GGUF_TYPE_STRING) { + std::string val = gguf_get_arr_str(ctx_gguf, i, j); + // escape quotes + replace_all(val, "\\", "\\\\"); + replace_all(val, "\"", "\\\""); + ss << '"' << val << '"'; + } else if (arr_type == GGUF_TYPE_ARRAY) { + ss << "???"; + } else { + ss << gguf_data_to_str(arr_type, data, j); + } + if (j < arr_n - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } + default: + return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); + } +} diff --git a/src/llama-impl.h b/src/llama-impl.h index 70f16b61c..12d1fb082 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -1,10 +1,9 @@ #pragma once -#include "llama.h" +#include "ggml.h" // for ggml_log_level #include #include -#include #ifdef __GNUC__ #ifdef __MINGW32__ @@ -35,147 +34,28 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * // helpers // -struct time_meas { - time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? 
-1 : ggml_time_us()), t_acc(t_acc) {} +template +struct no_init { + T value; + no_init() { /* do nothing */ } +}; - ~time_meas() { - if (t_start_us >= 0) { - t_acc += ggml_time_us() - t_start_us; - } - } +struct time_meas { + time_meas(int64_t & t_acc, bool disable = false); + ~time_meas(); const int64_t t_start_us; int64_t & t_acc; }; -static void replace_all(std::string & s, const std::string & search, const std::string & replace) { - if (search.empty()) { - return; - } - std::string builder; - builder.reserve(s.length()); - size_t pos = 0; - size_t last_pos = 0; - while ((pos = s.find(search, last_pos)) != std::string::npos) { - builder.append(s, last_pos, pos - last_pos); - builder.append(replace); - last_pos = pos + search.length(); - } - builder.append(s, last_pos, std::string::npos); - s = std::move(builder); -} +void replace_all(std::string & s, const std::string & search, const std::string & replace); -const std::vector> & llama_internal_get_tensor_map( - struct llama_context * ctx -); +// TODO: rename to llama_format ? +LLAMA_ATTRIBUTE_FORMAT(1, 2) +std::string format(const char * fmt, ...); -// the ring buffer works similarly to std::deque, but with a fixed capacity -template -struct ring_buffer { - ring_buffer(size_t cap) : capacity(cap), data(cap) {} +std::string llama_format_tensor_shape(const std::vector & ne); +std::string llama_format_tensor_shape(const struct ggml_tensor * t); - T & front() { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[first]; - } - - const T & front() const { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[first]; - } - - T & back() { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[pos]; - } - - const T & back() const { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[pos]; - } - - void push_back(const T & value) { - if (capacity == 0) { - throw std::runtime_error("ring buffer: capacity is zero"); - } - - if (sz == capacity) { - // advance the start when buffer is full - first = (first + 1) % capacity; - } else { - sz++; - } - data[pos] = value; - pos = (pos + 1) % capacity; - } - - T pop_front() { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - T value = data[first]; - first = (first + 1) % capacity; - sz--; - return value; - } - - //T & operator[](size_t i) { - // if (i >= sz) { - // throw std::runtime_error("ring buffer: index out of bounds"); - // } - // return data[(first + i) % capacity]; - //} - - //const T & at(size_t i) const { - // if (i >= sz) { - // throw std::runtime_error("ring buffer: index out of bounds"); - // } - // return data[(first + i) % capacity]; - //} - - const T & rat(size_t i) const { - if (i >= sz) { - throw std::runtime_error("ring buffer: index out of bounds"); - } - return data[(first + sz - i - 1) % capacity]; - } - - std::vector to_vector() const { - std::vector result; - result.reserve(sz); - for (size_t i = 0; i < sz; i++) { - result.push_back(data[(first + i) % capacity]); - } - return result; - } - - void clear() { - // here only reset the status of the buffer - sz = 0; - first = 0; - pos = 0; - } - - bool empty() const { - return sz == 0; - } - - size_t size() const { - return sz; - } - - size_t capacity = 0; - size_t sz = 0; - size_t first = 0; - size_t pos = 0; - std::vector data; -}; +std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp new file mode 
100644 index 000000000..90b6c56ed --- /dev/null +++ b/src/llama-kv-cache.cpp @@ -0,0 +1,718 @@ +#include "llama-kv-cache.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-cparams.h" +#include "llama-model.h" + +#include +#include +#include + +static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; + +uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} + +bool llama_kv_cache_init( + struct llama_kv_cache & cache, + const llama_model & model, + const llama_cparams & cparams, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + bool offload) { + const struct llama_hparams & hparams = model.hparams; + + const int32_t n_layer = hparams.n_layer; + + cache.has_shift = false; + + cache.recurrent = llama_model_is_recurrent(&model); + cache.v_trans = !cache.recurrent && !cparams.flash_attn; + cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + + LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", + __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift); + + cache.head = 0; + cache.size = kv_size; + cache.used = 0; + + cache.type_k = type_k; + cache.type_v = type_v; + + cache.cells.clear(); + cache.cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + struct ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + ctx_map[buft] = ctx; + cache.ctxs.emplace_back(ctx); + return ctx; + } + return it->second; + }; + + cache.k_l.reserve(n_layer); + cache.v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa); + + ggml_backend_buffer_type_t buft; + if (offload) { + auto * dev = model.dev_layer.at(i).dev; + buft = ggml_backend_dev_buffer_type(dev); + } else { + buft = ggml_backend_cpu_buffer_type(); + } + ggml_context * ctx = ctx_for_buft(buft); + + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); + return false; + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.k_l.push_back(k); + cache.v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); + return false; + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), 
ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + cache.bufs.emplace_back(buf); + } + + return true; +} + +struct llama_kv_cache_slot_info llama_kv_cache_find_slot( + struct llama_kv_cache & cache, + const struct llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + if (cache.recurrent) { + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = cache.size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + return llama_kv_cache_slot_info_failed; + } + if (j > 0) { + llama_kv_cell & seq = cache.cells[seq_id]; + if (seq.tail >= 0) { + llama_kv_cell & cell = cache.cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + cache.used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(cache.size, -1); + for (uint32_t i = 0; i < cache.size; ++i) { + llama_kv_cell & cell = cache.cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < cache.size; ++i) { + if (tails_verif[i] != cache.cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = cache.head; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } + llama_kv_cell & cell = cache.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + llama_kv_cell & seq_meta = cache.cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + llama_kv_cell & cell = cache.cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? 
+ if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + llama_kv_cell & empty_cell = cache.cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + llama_kv_cell & orig_cell = cache.cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < cache.size; ++i) { + if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } + llama_kv_cell & cell = cache.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + llama_kv_cell & dst_cell = cache.cells[dst_id]; + llama_kv_cell & src_cell = cache.cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cache.cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cache.cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + llama_kv_cell & cell = cache.cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cache.cells[seq_id].tail = cell_id; + } + } + + // allow getting the range of used cells, from head to head + n + cache.head = min; + cache.n = max - min + 1; + cache.used = std::count_if(cache.cells.begin(), cache.cells.end(), + [](const llama_kv_cell& cell){ return !cell.is_empty(); }); + + // sanity check + return llama_kv_cache_slot_info(cache.n >= n_seqs); + } + // otherwise, one cell per token. 
+ + if (n_tokens > cache.size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); + return llama_kv_cache_slot_info_failed; + } + + uint32_t n_tested = 0; + + while (true) { + if (cache.head + n_tokens > cache.size) { + n_tested += cache.size - cache.head; + cache.head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.cells[cache.head + i].pos >= 0) { + found = false; + cache.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= cache.size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return llama_kv_cache_slot_info_failed; + } + } + + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cache.cells[cache.head + k].pos = ubatch.pos[k]; + + for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { + cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]); + } + } + } + + cache.used += n_tokens; + + return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens); +} + +uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { + for (uint32_t i = cache.size; i > 0; --i) { + const llama_kv_cell & cell = cache.cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + +void llama_kv_cache_clear(struct llama_kv_cache & cache) { + for (int32_t i = 0; i < (int32_t) cache.size; ++i) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + cache.cells[i].src = -1; + cache.cells[i].tail = -1; + } + cache.head = 0; + cache.used = 0; + + for (auto & buf : cache.bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + // models like Mamba or RWKV can't have a state partially erased + if (cache.recurrent) { + if (seq_id >= (int64_t) cache.size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + const llama_kv_cell & cell = cache.cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + } + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + if (seq_id < 0) { + cache.cells[i].seq_id.clear(); + } else if (cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cache.cells[i].is_empty()) { + // keep count of the number of used cells + if (cache.cells[i].pos >= 0) cache.used--; + + cache.cells[i].pos = -1; + cache.cells[i].src = -1; + if (new_head == cache.size) new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. 
+ if (new_head != cache.size && new_head < cache.head) cache.head = new_head; + + return true; +} + +void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + if (cache.recurrent) { + if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { + llama_kv_cell & tail_src = cache.cells[seq_id_src]; + llama_kv_cell & tail_dst = cache.cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + llama_kv_cell & cell_dst = cache.cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.delta = -1; + cell_dst.src = -1; + cache.used -= 1; + } + } + if (tail_src.tail >= 0) { + llama_kv_cell & cell_src = cache.cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } + + return; + } + // otherwise, this is the KV cache of a Transformer-like model + + cache.head = 0; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.insert(seq_id_dst); + } + } +} + +void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { + uint32_t new_head = cache.size; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.recurrent && (llama_seq_id) i != seq_id) { + cache.cells[i].tail = -1; + } + if (!cache.cells[i].has_seq_id(seq_id)) { + if (cache.cells[i].pos >= 0) cache.used--; + cache.cells[i].pos = -1; + cache.cells[i].src = -1; + cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; + } else { + cache.cells[i].seq_id.clear(); + cache.cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size && new_head < cache.head) cache.head = new_head; +} + +void llama_kv_cache_seq_add( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) return; + + if (cache.recurrent) { + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) cache.size) { + const int32_t tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cache.cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += delta; + } + } + } + return; + } + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.has_shift = true; + cache.cells[i].pos += delta; + cache.cells[i].delta += delta; + + if (cache.cells[i].pos < 0) { + if (!cache.cells[i].is_empty()) { + cache.used--; + } + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + if (new_head == cache.size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.head = new_head != cache.size ? 
new_head : 0; +} + +void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) return; + + if (cache.recurrent) { + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) cache.size) { + const int32_t tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cache.cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } + return; + } + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.has_shift = true; + + { + llama_pos p_old = cache.cells[i].pos; + cache.cells[i].pos /= d; + cache.cells[i].delta += cache.cells[i].pos - p_old; + } + } + } +} + +llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { + llama_pos result = 0; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id)) { + result = std::max(result, cache.cells[i].pos); + } + } + + return result; +} + +void llama_kv_cache_defrag(struct llama_kv_cache & cache) { + if (!cache.recurrent) { + cache.do_defrag = true; + } +} + +int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) { + int result = 0; + + for (uint32_t i = 0; i < kv.size; i++) { + result += kv.cells[i].seq_id.size(); + } + + return result; +} + +int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) { + return kv.used; +} + +bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) { + return kv.can_shift; +} + +// +// kv cache view +// + +struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) { + struct llama_kv_cache_view result = { + /*.n_cells = */ 0, + /*.n_seq_max = */ n_seq_max, + /*.token_count = */ 0, + /*.used_cells = */ llama_get_kv_cache_used_cells(kv), + /*.max_contiguous = */ 0, + /*.max_contiguous_idx = */ -1, + /*.cells = */ nullptr, + /*.cells_sequences = */ nullptr, + }; + + return result; +} + +void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { + if (view->cells != nullptr) { + free(view->cells); + view->cells = nullptr; + } + if (view->cells_sequences != nullptr) { + free(view->cells_sequences); + view->cells_sequences = nullptr; + } +} + +void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) { + if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) { + view->n_cells = int32_t(kv.size); + void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); + GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); + view->cells = (struct llama_kv_cache_view_cell *)p; + p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells); + GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences"); + view->cells_sequences = (llama_seq_id *)p; + } + + const std::vector & kv_cells = kv.cells; + llama_kv_cache_view_cell * c_curr = view->cells; + llama_seq_id * cs_curr = view->cells_sequences; + int32_t used_cells = 0; + int32_t token_count = 0; + int32_t curr_contig_idx = -1; + uint32_t max_contig = 0; + int32_t max_contig_idx = -1; + + for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, 
cs_curr += view->n_seq_max) { + const size_t curr_size = kv_cells[i].seq_id.size(); + token_count += curr_size; + c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; + + if (curr_size > 0) { + if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) { + max_contig = i - curr_contig_idx; + max_contig_idx = curr_contig_idx; + } + curr_contig_idx = -1; + } else if (curr_contig_idx < 0) { + curr_contig_idx = i; + } + + int seq_idx = 0; + for (const llama_seq_id it : kv_cells[i].seq_id) { + if (seq_idx >= view->n_seq_max) { + break; + } + cs_curr[seq_idx] = it; + seq_idx++; + } + if (seq_idx != 0) { + used_cells++; + } + for (; seq_idx < view->n_seq_max; seq_idx++) { + cs_curr[seq_idx] = -1; + } + } + if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) { + max_contig_idx = curr_contig_idx; + max_contig = kv_cells.size() - curr_contig_idx; + } + view->max_contiguous = max_contig; + view->max_contiguous_idx = max_contig_idx; + view->token_count = token_count; + view->used_cells = used_cells; + if (uint32_t(used_cells) != kv.used) { + LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n", + __func__, kv.used, used_cells); + } +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h new file mode 100644 index 000000000..dca6f3998 --- /dev/null +++ b/src/llama-kv-cache.h @@ -0,0 +1,218 @@ +#pragma once + +#include "llama.h" + +#include "ggml-cpp.h" + +#include +#include + +struct llama_kv_cell { + llama_pos pos = -1; + llama_pos delta = 0; + int32_t src = -1; // used by recurrent state models to copy states + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const llama_kv_cell & other) const { + return seq_id == other.seq_id; + } +}; + +// ring-buffer of cached KV data +struct llama_kv_cache { + bool has_shift = false; + bool do_defrag = false; + bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool v_trans = true; // the value tensor is transposed + bool can_shift = false; + + // Note: The value of head isn't only used to optimize searching + // for a free KV slot. llama_decode_internal also uses it, so it + // cannot be freely changed after a slot has been allocated. + uint32_t head = 0; + uint32_t size = 0; + uint32_t used = 0; // used cells (i.e. 
at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + ggml_type type_k = GGML_TYPE_F16; + ggml_type type_v = GGML_TYPE_F16; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + + std::vector ctxs; + std::vector bufs; + + size_t total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; + } + + // TODO: better data structures to reduce the cost of this operation + llama_pos max_pos() const { + llama_pos max_pos = -1; + for (const auto & cell : cells) { + max_pos = std::max(max_pos, cell.pos); + } + + return max_pos; + } +}; + +// a structure holds information about the slot found in llama_kv_cache_find_slot +struct llama_kv_cache_slot_info { + std::pair boundaries; // slot boundaries [begin, end) + bool found = false; // the slot was found + + explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} + llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} + + operator bool() const { return found; } +}; + +// TODO: maybe not needed +uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams); + +bool llama_kv_cache_init( + struct llama_kv_cache & cache, + const llama_model & model, + const llama_cparams & cparams, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + bool offload); + +// find an empty slot of size "n_tokens" in the cache +// updates the cache head +// returns a structure holding information about the slot found +// Note: On success, it's important that cache.head points +// to the first cell of the slot. +struct llama_kv_cache_slot_info llama_kv_cache_find_slot( + struct llama_kv_cache & cache, + const struct llama_ubatch & batch); + +// find how many cells are currently in use +uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache); + +void llama_kv_cache_clear(struct llama_kv_cache & cache); + +bool llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1); + +void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); + +void llama_kv_cache_seq_keep( + struct llama_kv_cache & cache, + llama_seq_id seq_id); + +void llama_kv_cache_seq_add( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta); + +void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d); + +llama_pos llama_kv_cache_seq_pos_max( + struct llama_kv_cache & cache, + llama_seq_id seq_id); + +void llama_kv_cache_defrag(struct llama_kv_cache & cache); + +int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv); + +int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv); + +bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv); + +// +// kv cache view +// + +struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max); + +void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv); + +// +// kv cache restore +// + +// saves the kv_cache state for future recovery. +// used to rollback llama_kv_cache_find_slot changes. 
+struct llama_kv_slot_restorer { + struct llama_kv_cache_state { + uint32_t head = 0; + uint32_t n = 0; + } old_state; + + // for non-recurrent models only + // list of slots to restore + std::vector> slot_boundaries; + + bool do_restore = false; + + explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) { + old_state.head = cache.head; + old_state.n = cache.n; + } + + // saves a slot information for future restoration + void save(const struct llama_kv_cache_slot_info & slot) { + if (slot) { + do_restore = true; + if (slot.boundaries.first != slot.boundaries.second) { + slot_boundaries.push_back(slot.boundaries); + } + } + } + + // must be explicitly called to restore the kv_cache state + // and rollback changes from all llama_kv_cache_find_slot calls + void restore(struct llama_kv_cache & cache) { + if (do_restore) { + cache.head = old_state.head; + cache.n = old_state.n; + + if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased + llama_kv_cache_seq_rm(cache, -1, -1, -1); + } else { + for (auto & slot : slot_boundaries) { + llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second); + } + } + } + } +}; + diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp new file mode 100644 index 000000000..a8cb9439b --- /dev/null +++ b/src/llama-mmap.cpp @@ -0,0 +1,589 @@ +#include "llama-mmap.h" + +#include "llama-impl.h" + +#include "ggml.h" + +#include +#include +#include + +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #include + #endif + #if defined(_POSIX_MEMLOCK_RANGE) + #include + #endif + #endif +#endif + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #ifndef PATH_MAX + #define PATH_MAX MAX_PATH + #endif + #include +#endif + +// TODO: consider moving to llama-impl.h if needed in more places +#if defined(_WIN32) +std::string llama_format_win_err(DWORD err) { + LPSTR buf; + size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); + if (!size) { + return "FormatMessageA failed"; + } + std::string ret(buf, size); + LocalFree(buf); + return ret; +} +#endif + +// llama_file + +struct llama_file::impl { +#if defined(_WIN32) + HANDLE fp_win32; + std::string GetErrorMessageWin32(DWORD error_code) const { + std::string ret; + LPSTR lpMsgBuf = NULL; + DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); + if (!bufLen) { + ret = format("Win32 error code: %lx", error_code); + } else { + ret = lpMsgBuf; + LocalFree(lpMsgBuf); + } + + return ret; + } + + impl(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { + LARGE_INTEGER li; + li.QuadPart = 0; + BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + + return li.QuadPart; + } + + void seek(size_t offset, int whence) const { + static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN"); + 
static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT"); + static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END"); + + LARGE_INTEGER li; + li.QuadPart = offset; + BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + } + + void read_raw(void * ptr, size_t len) const { + size_t bytes_read = 0; + while (bytes_read < len) { + size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); + DWORD chunk_read = 0; + BOOL result = ReadFile(fp_win32, reinterpret_cast(ptr) + bytes_read, chunk_size, &chunk_read, NULL); + if (!result) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_read < chunk_size || chunk_read == 0) { + throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += chunk_read; + } + } + + uint32_t read_u32() const { + uint32_t val; + read_raw(&val, sizeof(val)); + return val; + } + + void write_raw(const void * ptr, size_t len) const { + size_t bytes_written = 0; + while (bytes_written < len) { + size_t chunk_size = std::min(len - bytes_written, 64*1024*1024); + DWORD chunk_written = 0; + BOOL result = WriteFile(fp_win32, reinterpret_cast(ptr) + bytes_written, chunk_size, &chunk_written, NULL); + if (!result) { + throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_written < chunk_size || chunk_written == 0) { + throw std::runtime_error("unexpectedly failed to write bytes"); + } + + bytes_written += chunk_written; + } + } + + void write_u32(uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~impl() { + if (fp) { + std::fclose(fp); + } + } +#else + impl(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { +// TODO: this ifdef is never true? +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + + return (size_t) ret; + } + + void seek(size_t offset, int whence) const { +// TODO: this ifdef is never true? 
+#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + if (ret != 0) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } + } + + void read_raw(void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, len, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error("unexpectedly reached end of file"); + } + } + + uint32_t read_u32() const { + uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + void write_raw(const void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + size_t ret = std::fwrite(ptr, len, 1, fp); + if (ret != 1) { + throw std::runtime_error(format("write error: %s", strerror(errno))); + } + } + + void write_u32(uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~impl() { + if (fp) { + std::fclose(fp); + } + } +#endif + + FILE * fp; + size_t size; +}; + +llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique(fname, mode)) {} +llama_file::~llama_file() = default; + +size_t llama_file::tell() const { return pimpl->tell(); } +size_t llama_file::size() const { return pimpl->size; } + +int llama_file::file_id() const { +#ifdef _WIN32 + return _fileno(pimpl->fp); +#else +#if defined(fileno) + return fileno(pimpl->fp); +#else + return ::fileno(pimpl->fp); +#endif +#endif +} + +void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } +void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); } + +uint32_t llama_file::read_u32() const { return pimpl->read_u32(); } + +void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); } +void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); } + +// llama_mmap + +struct llama_mmap::impl { +#ifdef _POSIX_MAPPED_FILES + std::vector> mapped_fragments; + + impl(struct llama_file * file, size_t prefetch, bool numa) { + size = file->size(); + int fd = file->file_id(); + int flags = MAP_SHARED; + if (numa) { prefetch = 0; } +#ifdef __linux__ + if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) { + LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n", + strerror(errno)); + } + if (prefetch) { flags |= MAP_POPULATE; } +#endif + addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0); + if (addr == MAP_FAILED) { + throw std::runtime_error(format("mmap failed: %s", strerror(errno))); + } + + if (prefetch > 0) { + if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) { + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", + strerror(errno)); + } + } + if (numa) { + if (posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) { + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", + strerror(errno)); + } + } + + mapped_fragments.emplace_back(0, file->size()); + } + + static void align_range(size_t * first, size_t * last, size_t page_size) { + size_t offset_in_page = *first & (page_size - 1); + size_t offset_to_page = offset_in_page == 0 ? 
0 : page_size - offset_in_page; + *first += offset_to_page; + + *last = *last & ~(page_size - 1); + + if (*last <= *first) { + *last = *first; + } + } + + void unmap_fragment(size_t first, size_t last) { + int page_size = sysconf(_SC_PAGESIZE); + align_range(&first, &last, page_size); + size_t len = last - first; + + if (len == 0) { + return; + } + + GGML_ASSERT(first % page_size == 0); + GGML_ASSERT(last % page_size == 0); + GGML_ASSERT(last > first); + + void * next_page_start = (uint8_t *) addr + first; + + if (munmap(next_page_start, len)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + + std::vector> new_mapped_fragments; + for (const auto & frag : mapped_fragments) { + if (frag.first < first && frag.second > last) { + new_mapped_fragments.emplace_back(frag.first, first); + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first < first && frag.second > first) { + new_mapped_fragments.emplace_back(frag.first, first); + } else if (frag.first < last && frag.second > last) { + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first >= first && frag.second <= last) { + } else { + new_mapped_fragments.push_back(frag); + } + } + mapped_fragments = std::move(new_mapped_fragments); + } + + ~impl() { + for (const auto & frag : mapped_fragments) { + if (munmap((char *) addr + frag.first, frag.second - frag.first)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + } + } +#elif defined(_WIN32) + impl(struct llama_file * file, size_t prefetch, bool numa) { + GGML_UNUSED(numa); + + size = file->size(); + + HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); + + HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + + if (hMapping == NULL) { + DWORD error = GetLastError(); + throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); + } + + addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); + DWORD error = GetLastError(); + CloseHandle(hMapping); + + if (addr == NULL) { + throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); + } + + if (prefetch > 0) { +#if _WIN32_WINNT >= 0x602 + BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG); + HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll"); + + pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory"); + + if (pPrefetchVirtualMemory) { + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T) std::min(size, prefetch); + if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + throw std::runtime_error("PrefetchVirtualMemory unavailable"); +#endif + } + } + + void unmap_fragment(size_t first, size_t last) { + GGML_UNUSED(first); + GGML_UNUSED(last); + } + + ~impl() { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + impl(struct llama_file * file, size_t prefetch, bool numa) { + GGML_UNUSED(file); + GGML_UNUSED(prefetch); + GGML_UNUSED(numa); + + throw std::runtime_error("mmap not supported"); + } + + void unmap_fragment(size_t first, size_t last) { + GGML_UNUSED(first); + GGML_UNUSED(last); + + throw std::runtime_error("mmap not supported"); + } +#endif + + 
void * addr; + size_t size; +}; + +llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa) : pimpl(std::make_unique(file, prefetch, numa)) {} +llama_mmap::~llama_mmap() = default; + +size_t llama_mmap::size() const { return pimpl->size; } +void * llama_mmap::addr() const { return pimpl->addr; } + +void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); } + +#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32) +const bool llama_mmap::SUPPORTED = true; +#else +const bool llama_mmap::SUPPORTED = false; +#endif + +// llama_mlock + +struct llama_mlock::impl { +#ifdef _POSIX_MEMLOCK_RANGE + static size_t lock_granularity() { + return (size_t) sysconf(_SC_PAGESIZE); + } + + bool raw_lock(const void * addr, size_t size) const { + if (!mlock(addr, size)) { + return true; + } + +#ifdef __APPLE__ +#define MLOCK_SUGGESTION \ + "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \ + "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n" +#else +#define MLOCK_SUGGESTION \ + "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n" +#endif + + char* errmsg = std::strerror(errno); + bool suggest = (errno == ENOMEM); + + struct rlimit lock_limit; + if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { + suggest = false; + } + if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + suggest = false; + } + + LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", + size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); + return false; + } + + static void raw_unlock(void * addr, size_t size) { + if (munlock(addr, size)) { + LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno)); + } + } +#elif defined(_WIN32) + static size_t lock_granularity() { + SYSTEM_INFO si; + GetSystemInfo(&si); + return (size_t) si.dwPageSize; + } + + bool raw_lock(void * ptr, size_t len) const { + for (int tries = 1; ; tries++) { + if (VirtualLock(ptr, len)) { + return true; + } + if (tries == 2) { + LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", + len, size, llama_format_win_err(GetLastError()).c_str()); + return false; + } + + SIZE_T min_ws_size, max_ws_size; + if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) { + LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + size_t increment = len + 1048576; + min_ws_size += increment; + max_ws_size += increment; + if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) { + LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + } + } + + static void raw_unlock(void * ptr, size_t len) { + if (!VirtualUnlock(ptr, len)) { + LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static size_t lock_granularity() { + return (size_t) 65536; + } + + bool raw_lock(const void * addr, size_t len) const { + LLAMA_LOG_WARN("warning: mlock not supported on this system\n"); + return false; + } + + static void raw_unlock(const void * addr, size_t len) {} +#endif + + impl() : addr(NULL), size(0), failed_already(false) {} + + void init(void * ptr) { + GGML_ASSERT(addr == NULL && size == 0); + addr = ptr; + } + + void 
grow_to(size_t target_size) { + GGML_ASSERT(addr); + if (failed_already) { + return; + } + size_t granularity = lock_granularity(); + target_size = (target_size + granularity - 1) & ~(granularity - 1); + if (target_size > size) { + if (raw_lock((uint8_t *) addr + size, target_size - size)) { + size = target_size; + } else { + failed_already = true; + } + } + } + + void * addr; + size_t size; + + bool failed_already; +}; + +llama_mlock::llama_mlock() : pimpl(std::make_unique()) {} +llama_mlock::~llama_mlock() = default; + +void llama_mlock::init(void * ptr) { pimpl->init(ptr); } +void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); } + +#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32) +const bool llama_mlock::SUPPORTED = true; +#else +const bool llama_mlock::SUPPORTED = false; +#endif + +size_t llama_path_max() { + return PATH_MAX; +} diff --git a/src/llama-mmap.h b/src/llama-mmap.h new file mode 100644 index 000000000..1da9ecb6b --- /dev/null +++ b/src/llama-mmap.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +struct llama_file; +struct llama_mmap; +struct llama_mlock; + +using llama_files = std::vector>; +using llama_mmaps = std::vector>; +using llama_mlocks = std::vector>; + +struct llama_file { + llama_file(const char * fname, const char * mode); + ~llama_file(); + + size_t tell() const; + size_t size() const; + + int file_id() const; // fileno overload + + void seek(size_t offset, int whence) const; + + void read_raw(void * ptr, size_t len) const; + uint32_t read_u32() const; + + void write_raw(const void * ptr, size_t len) const; + void write_u32(uint32_t val) const; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +struct llama_mmap { + llama_mmap(const llama_mmap &) = delete; + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false); + ~llama_mmap(); + + size_t size() const; + void * addr() const; + + void unmap_fragment(size_t first, size_t last); + + static const bool SUPPORTED; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +struct llama_mlock { + llama_mlock(); + ~llama_mlock(); + + void init(void * ptr); + void grow_to(size_t target_size); + + static const bool SUPPORTED; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +size_t llama_path_max(); diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp new file mode 100644 index 000000000..1c4e30878 --- /dev/null +++ b/src/llama-model-loader.cpp @@ -0,0 +1,1011 @@ +#include "llama-model-loader.h" + +#include "ggml.h" + +#include +#include +#include +#include + +const char * llama_file_version_name(llama_fver version) { + switch (version) { + case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; + case GGUF_FILE_VERSION_V2: return "GGUF V2"; + case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)"; + } + + return "unknown"; +} + +namespace GGUFMeta { + template + struct GKV_Base_Type { + static constexpr gguf_type gt = gt_; + + static T getter(const gguf_context * ctx, const int kid) { + return gfun(ctx, kid); + } + }; + + template struct GKV_Base; + + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: 
GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + + template<> struct GKV_Base { + static constexpr gguf_type gt = GGUF_TYPE_STRING; + + static std::string getter(const gguf_context * ctx, const int kid) { + return gguf_get_val_str(ctx, kid); + } + }; + + struct ArrayInfo { + const gguf_type gt; + const size_t length; + const void * data; + }; + + template<> struct GKV_Base { + public: + static constexpr gguf_type gt = GGUF_TYPE_ARRAY; + static ArrayInfo getter(const gguf_context *ctx, const int k) { + const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); + return ArrayInfo { + arr_type, + size_t(gguf_get_arr_n(ctx, k)), + arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), + }; + } + }; + + template + class GKV : public GKV_Base { + GKV() = delete; + + public: + static T get_kv(const gguf_context * ctx, const int k) { + const enum gguf_type kt = gguf_get_kv_type(ctx, k); + + if (kt != GKV::gt) { + throw std::runtime_error(format("key %s has wrong type %s but expected type %s", + gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt))); + } + return GKV::getter(ctx, k); + } + + static const char * override_type_to_str(const llama_model_kv_override_type ty) { + switch (ty) { + case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool"; + case LLAMA_KV_OVERRIDE_TYPE_INT: return "int"; + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float"; + case LLAMA_KV_OVERRIDE_TYPE_STR: return "str"; + } + return "unknown"; + } + + static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) { + if (!ovrd) { return false; } + if (ovrd->tag == expected_type) { + LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ", + __func__, override_type_to_str(ovrd->tag), ovrd->key); + switch (ovrd->tag) { + case LLAMA_KV_OVERRIDE_TYPE_BOOL: { + LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false"); + } break; + case LLAMA_KV_OVERRIDE_TYPE_INT: { + LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { + LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_STR: { + LLAMA_LOG_INFO("%s\n", ovrd->val_str); + } break; + default: + // Shouldn't be possible to end up here, but just in case... 
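Reviewer note: `GKV::get_kv` above refuses to read a key whose stored GGUF type does not match the requested C++ type and throws instead of silently converting. A minimal sketch of that idea, using a `std::variant` as a stand-in for the GGUF key/value store (the `kv_value` and `get_kv` names here are illustrative only, not the patch's API):

#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <variant>

// Hypothetical stand-in for one GGUF key/value entry.
using kv_value = std::variant<bool, int64_t, double, std::string>;

// Typed getter in the spirit of GKV::get_kv: the caller asks for a concrete type
// and gets an exception (not a silent conversion) if the stored type differs.
template <typename T>
T get_kv(const kv_value & v, const char * key) {
    if (!std::holds_alternative<T>(v)) {
        throw std::runtime_error(std::string("key ") + key + " has wrong type");
    }
    return std::get<T>(v);
}

int main() {
    kv_value n_layer = int64_t{32};
    std::printf("n_layer = %lld\n", (long long) get_kv<int64_t>(n_layer, "block_count"));
    try {
        get_kv<std::string>(n_layer, "block_count");  // wrong type -> throws
    } catch (const std::exception & e) {
        std::printf("%s\n", e.what());
    }
}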
+ throw std::runtime_error( + format("Unsupported attempt to override %s type for metadata key %s\n", + override_type_to_str(ovrd->tag), ovrd->key)); + } + return true; + } + LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n", + __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag)); + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) { + target = ovrd->val_bool; + return true; + } + return false; + } + + template + static typename std::enable_if::value && std::is_integral::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) { + target = ovrd->val_i64; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) { + target = ovrd->val_f64; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) { + target = ovrd->val_str; + return true; + } + return false; + } + + static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + if (try_override(target, ovrd)) { + return true; + } + if (k < 0) { return false; } + target = get_kv(ctx, k); + return true; + } + + static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + return set(ctx, gguf_find_key(ctx, key), target, ovrd); + } + + static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + return set(ctx, key.c_str(), target, ovrd); + } + }; +} + + template + typename std::enable_if::value, bool>::type + llama_model_loader::get_arr_n(const std::string & key, T & result, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + + result = arr_info.length; + return true; + } + + template + typename std::enable_if::value, bool>::type + llama_model_loader::get_arr_n(enum llm_kv kid, T & result, bool required) { + return get_arr_n(llm_kv(kid), result, required); + } + + template bool llama_model_loader::get_arr_n(enum llm_kv kid, uint32_t & result, bool required); + + template + bool llama_model_loader::get_arr(const std::string & key, std::vector & result, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; 
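Reviewer note: the `try_override` overloads above make a metadata override win over whatever the GGUF file contains, and `GKV::set` only falls back to the file lookup when no override applies. A small sketch of that precedence, with plain maps standing in for the file metadata and the override table (the `set_kv` helper is hypothetical):

#include <cstdio>
#include <map>
#include <string>

// Minimal illustration of the precedence in GKV::set():
// an override for the key wins over whatever the file contains.
template <typename T>
bool set_kv(const std::map<std::string, T> & file_kv,
            const std::map<std::string, T> & overrides,
            const std::string & key, T & target) {
    auto ov = overrides.find(key);
    if (ov != overrides.end()) {   // mirrors try_override(): short-circuits the file lookup
        target = ov->second;
        return true;
    }
    auto it = file_kv.find(key);
    if (it == file_kv.end()) {
        return false;              // caller decides whether a missing key is an error
    }
    target = it->second;
    return true;
}

int main() {
    std::map<std::string, int> file_kv   = {{"n_ctx", 4096}};
    std::map<std::string, int> overrides = {{"n_ctx", 8192}};

    int n_ctx = 0;
    set_kv(file_kv, overrides, std::string("n_ctx"), n_ctx);
    std::printf("n_ctx = %d\n", n_ctx);  // 8192: the override wins
}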
+ default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } + + result.resize(arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + + return true; + } + + template + bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } + + if (arr_info.length > N_MAX) { + throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); + } + + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + + return true; + } + + template + bool llama_model_loader::get_arr(enum llm_kv kid, T & result, bool required) { + return get_arr(llm_kv(kid), result, required); + } + + template + bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { + auto it = kv_overrides.find(key); + + const struct llama_model_kv_override * override = + it != kv_overrides.end() ? &it->second : nullptr; + + const bool found = GGUFMeta::GKV::set(meta.get(), key, result, override); + + if (required && !found) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + + return found; + } + + template + bool llama_model_loader::get_key(enum llm_kv kid, T & result, bool required) { + return get_key(llm_kv(kid), result, required); + } + + template bool llama_model_loader::get_key (enum llm_kv kid, bool & result, bool required); + template bool llama_model_loader::get_key (enum llm_kv kid, float & result, bool required); + template bool llama_model_loader::get_key (enum llm_kv kid, uint32_t & result, bool required); + template bool llama_model_loader::get_key(enum llm_kv kid, std::string & result, bool required); + + template<> + bool llama_model_loader::get_key(enum llm_kv kid, enum llama_pooling_type & result, bool required) { + uint32_t tmp; + const bool found = get_key(kid, tmp, required); + if (found) { + result = (enum llama_pooling_type) tmp; + } else { + result = LLAMA_POOLING_TYPE_UNSPECIFIED; + } + return found; + } + + // get array of n <= N_MAX elements, or a single element repeated n times + template + bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + if (n > N_MAX) { + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + } + + if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) { + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + if (n != arr_info.length) { + throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", 
key.c_str(), n, (uint32_t) arr_info.length)); + } + + return get_arr(key, result, required); + } + + T value; + + bool ok = get_key(key, value, required); + if (!ok) { + return false; + } + + for (uint32_t i = 0; i < n; i++) { + result[i] = value; + } + + return true; + } + + template + bool llama_model_loader::get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required) { + return get_key_or_arr(llm_kv(kid), result, n, required); + } + + // TODO: this is not very clever - figure out something better + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + +llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { + int trace = 0; + if (getenv("LLAMA_TRACE")) { + trace = atoi(getenv("LLAMA_TRACE")); + } + + if (param_overrides_p != nullptr) { + for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) { + kv_overrides.insert({std::string(p->key), *p}); + } + } + + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + + meta.reset(gguf_init_from_file(fname.c_str(), params)); + if (!meta) { + throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + } + + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(fname.c_str(), "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset of the main file. + // For subsidiary files, `meta` tensor data offset must not be used, + // so we build a unified tensors index for weights. 
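Reviewer note: `get_key_or_arr` accepts either an array of exactly `n` values or a single scalar that is broadcast to all `n` entries, which is how per-layer hyperparameters such as head counts can be stored compactly. A self-contained illustration of that behaviour, with a `std::variant` standing in for the GGUF value (the `kv_value` type and free function are assumptions made for the example):

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <variant>
#include <vector>

// Hypothetical stand-in for a GGUF value: either one scalar or an array.
using kv_value = std::variant<uint32_t, std::vector<uint32_t>>;

template <size_t N_MAX>
void get_key_or_arr(const kv_value & v, std::array<uint32_t, N_MAX> & out, uint32_t n) {
    if (n > N_MAX) {
        throw std::runtime_error("n > N_MAX");
    }
    if (auto * arr = std::get_if<std::vector<uint32_t>>(&v)) {
        if (arr->size() != n) {
            throw std::runtime_error("wrong array length");
        }
        std::copy(arr->begin(), arr->end(), out.begin());   // per-layer values
    } else {
        uint32_t scalar = std::get<uint32_t>(v);
        for (uint32_t i = 0; i < n; i++) {
            out[i] = scalar;                                 // single value repeated n times
        }
    }
}

int main() {
    std::array<uint32_t, 8> n_head{};
    get_key_or_arr(kv_value{uint32_t{32}}, n_head, 4);                         // broadcast
    get_key_or_arr(kv_value{std::vector<uint32_t>{32, 32, 8, 8}}, n_head, 4);  // per-layer
    std::printf("%u %u %u %u\n", n_head[0], n_head[1], n_head[2], n_head[3]);  // 32 32 8 8
}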
+ for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur)); + } + uint16_t n_split = 0; + get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); + + // Load additional GGML contexts + if (n_split > 1) { + uint16_t idx = 0; + get_key(llm_kv(LLM_KV_SPLIT_NO), idx); + if (idx != 0) { + throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx)); + } + + std::vector split_prefix(llama_path_max(), 0); + if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) { + throw std::runtime_error(format("invalid split file: %s", fname.c_str())); + } + + if (trace > 0) { + LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); + } + + std::vector split_path(llama_path_max(), 0); + for (idx = 1; idx < n_split; idx++) { + llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split); + + struct gguf_init_params split_params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) }; + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data())); + } + + files.emplace_back(new llama_file(split_path.data(), "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the shard. 
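Reviewer note: for split models the loader derives the remaining shard paths from the first file via `llama_split_prefix`/`llama_split_path` and opens each shard in turn. The sketch below assumes the usual `-%05d-of-%05d.gguf` suffix convention for shard names; the helper name `split_path` and the hard-coded prefix are illustrative only and not the patch's API.

#include <cstdio>
#include <string>
#include <vector>

// Build the expected file name of shard `idx` (0-based) out of `n_split`,
// assuming a "<prefix>-%05d-of-%05d.gguf" naming convention.
static std::string split_path(const std::string & prefix, int idx, int n_split) {
    char buf[64];
    std::snprintf(buf, sizeof(buf), "-%05d-of-%05d.gguf", idx + 1, n_split);
    return prefix + buf;
}

int main() {
    std::vector<std::string> paths;
    for (int idx = 1; idx < 3; idx++) {                     // shard 0 is the file we started from
        paths.push_back(split_path("model", idx, 3));
    }
    for (const auto & p : paths) {
        std::printf("%s\n", p.c_str());                     // model-00002-of-00003.gguf, model-00003-of-00003.gguf
    }
}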
+ for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); + } + } + + get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); + + // sanity check + { + const int n_tensors_loaded = (int) weights_map.size(); + if (n_tensors != n_tensors_loaded) { + throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); + } + } + + LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); + } + + n_kv = gguf_get_n_kv(meta.get()); + n_tensors = weights_map.size(); + + fver = (enum llama_fver) gguf_get_version(meta.get()); + + LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", + __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + + // determine file type based on the number of tensors for each quantization and print meta data + // TODO: make optional + { + std::map n_type; + + uint32_t n_type_max = 0; + enum ggml_type type_max = GGML_TYPE_F32; + + for (const auto & it : weights_map) { + const llama_tensor_weight & w = it.second; + const ggml_tensor * tensor = w.tensor; + + enum ggml_type type = tensor->type; + + n_type[type]++; + + if (n_type_max < n_type[type]) { + n_type_max = n_type[type]; + type_max = type; + } + + if (trace > 0) { + const uint16_t sid = w.idx; + LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str()); + } + } + + switch (type_max) { + case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; + case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; + case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break; + case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; + case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; + case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; + case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; + case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; + case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; + case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; + case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; + case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; + case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + case GGML_TYPE_TQ1_0: ftype = LLAMA_FTYPE_MOSTLY_TQ1_0; break; + case GGML_TYPE_TQ2_0: ftype = LLAMA_FTYPE_MOSTLY_TQ2_0; break; + case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; + case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; + case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break; + case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; + case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; + case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; + case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; + case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; + case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + 
default: + { + LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); + ftype = LLAMA_FTYPE_ALL_F32; + } break; + } + + // this is a way to mark that we have "guessed" the file type + ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); + + { + const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV + if (kid >= 0) { + ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid); + } + } + + LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); + + for (int i = 0; i < n_kv; i++) { + const char * name = gguf_get_key(meta.get(), i); + const enum gguf_type type = gguf_get_kv_type(meta.get(), i); + const std::string type_name = + type == GGUF_TYPE_ARRAY + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) + : gguf_type_name(type); + + std::string value = gguf_kv_to_str(meta.get(), i); + const size_t MAX_VALUE_LEN = 40; + if (value.size() > MAX_VALUE_LEN) { + value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); + } + replace_all(value, "\n", "\\n"); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); + } + + // print type counts + for (auto & kv : n_type) { + if (kv.second == 0) { + continue; + } + + LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); + } + } + + if (!llama_mmap::SUPPORTED) { + LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); + use_mmap = false; + } + + this->use_mmap = use_mmap; + this->check_tensors = check_tensors; +} + +std::string llama_model_loader::get_arch_name() const { + return arch_name; +} + +enum llm_arch llama_model_loader::get_arch() const { + return llm_kv.arch; +} + +const llama_model_loader::llama_tensor_weight * llama_model_loader::get_weight(const char * name) const { + auto pos = weights_map.find(name); + if (pos != weights_map.end()) { + return &pos->second; + } + + return nullptr; +} + +const llama_model_loader::llama_tensor_weight & llama_model_loader::require_weight(const char * name) const { + const llama_tensor_weight * weight = get_weight(name); + if (!weight) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name)); + } + return *weight; +} + +struct ggml_tensor * llama_model_loader::get_tensor_meta(const char * name) const { + const auto * weight = get_weight(name); + if (!weight) { + return nullptr; + } + return weight->tensor; +} + +struct ggml_tensor * llama_model_loader::require_tensor_meta(const std::string & name) const { + struct ggml_tensor * tensor = get_tensor_meta(name.c_str()); + if (!tensor) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + return tensor; +} + +const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const { + const struct ggml_tensor * cur = get_tensor_meta(name.c_str()); + + if (cur == NULL) { + if (!required) { + return NULL; + } + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + + { + bool is_ok = true; + for (size_t i = 0; i < GGML_MAX_DIMS; ++i) { + if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) { + is_ok = false; + break; + } + } + if (!is_ok) { + throw std::runtime_error( + format("%s: tensor '%s' has wrong shape; expected %s, got %s", + __func__, name.c_str(), + 
llama_format_tensor_shape(ne).c_str(), + llama_format_tensor_shape(cur).c_str())); + } + } + + return cur; +} + +struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); + + if (cur == NULL) { + return NULL; + } + + bool duplicated = flags & TENSOR_DUPLICATED; + + struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur); + ggml_set_name(tensor, ggml_get_name(cur)); + + if (duplicated) { + size_data += ggml_nbytes(cur); + } else { + n_created++; + } + + return tensor; + +} + +struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, required); + + if (cur == NULL) { + return NULL; + } + + if (cur->type != base->type) { + throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type))); + } + + std::array dims; + for (size_t i = 0; i < GGML_MAX_DIMS; ++i) { + dims[i] = i < ne.size() ? ne.begin()[i] : 1; + } + + struct ggml_tensor * tensor = ggml_view_4d(ctx, base, + dims[0], dims[1], dims[2], dims[3], + cur->nb[1], cur->nb[2], cur->nb[3], + offset); + + ggml_set_name(tensor, name.c_str()); + + n_created++; + + return tensor; +} + +void llama_model_loader::done_getting_tensors() const { + if (n_created != n_tensors) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } +} + +void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps) { + if (use_mmap) { + mappings.reserve(files.size()); + mmaps_used.reserve(files.size()); + for (const auto & file : files) { + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); + auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); + std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? 
-1 : 0, is_numa_fn())); + mmaps_used.emplace_back(mapping->size(), 0); + if (mlock_mmaps) { + std::unique_ptr mlock_mmap(new llama_mlock()); + mlock_mmap->init(mapping->addr()); + mlock_mmaps->emplace_back(std::move(mlock_mmap)); + } + mappings.emplace_back(std::move(mapping)); + } + } + + // compute the total size of all tensors for progress reporting + for (const auto & it : weights_map) { + size_data += ggml_nbytes(it.second.tensor); + } +} + +void llama_model_loader::get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const { + GGML_ASSERT(!mappings.empty()); + const auto & mapping = mappings.at(idx); + + *first = mapping->size(); + *last = 0; + *addr = mapping->addr(); + for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) { + const auto * weight = get_weight(ggml_get_name(tensor)); + if (!weight || weight->idx != idx) { + continue; + } + *first = std::min(*first, weight->offs); + *last = std::max(*last, weight->offs + ggml_nbytes(tensor)); + } +} + +void llama_model_loader::load_data_for(struct ggml_tensor * cur) const { + const auto & w = require_weight(ggml_get_name(cur)); + + if (use_mmap) { + const auto & mapping = mappings.at(w.idx); + if (cur->data == nullptr) { + cur->data = (uint8_t *)mapping->addr() + w.offs; + } else { + memcpy(cur->data, (uint8_t *)mapping->addr() + w.offs, ggml_nbytes(cur)); + } + } else { + GGML_ASSERT(cur->data != nullptr); + GGML_ASSERT(w.idx < files.size()); + const auto & file = files.at(w.idx); + file->seek(w.offs, SEEK_SET); + file->read_raw(cur->data, ggml_nbytes(cur)); + } + + if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } +} + +bool llama_model_loader::load_all_data( + struct ggml_context * ctx, + llama_buf_map & bufs, + llama_mlocks * lmlocks, + llama_progress_callback progress_callback, + void * progress_callback_user_data) { + GGML_ASSERT(size_data != 0 && "call init_mappings() first"); + + std::vector> read_buf; + std::vector>> validation_result; + + // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. + // NVMe raid configurations might require more / larger buffers. + constexpr size_t n_buffers = 4; + constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + std::vector host_buffers; + std::vector events; + std::vector host_ptrs; + size_t buffer_idx = 0; // buffer to use for async loads + ggml_backend_t upload_backend = [&](const char * func) -> ggml_backend_t { + if (use_mmap || check_tensors) { + return nullptr; + } + // When not using mmaped io use async uploads from pinned memory to GPU memory. + // First determine if the backend supports the necessary features for async uploads. + auto * buf = bufs.count(0) ? 
bufs.at(0) : nullptr; + if (!buf) { + LLAMA_LOG_DEBUG("%s: no buffer found for async uploads\n", func); + return nullptr; + } + + auto * buft = ggml_backend_buffer_get_type(buf); + auto * dev = ggml_backend_buft_get_device(buft); + if (!dev) { + LLAMA_LOG_DEBUG("%s: no device found for buffer type %s for async uploads\n", func, + ggml_backend_buft_name(buft)); + return nullptr; + } + + if (buft != ggml_backend_dev_buffer_type(dev)) { + LLAMA_LOG_DEBUG("%s: buffer type %s is not the default buffer type for device %s for async uploads\n", func, + ggml_backend_buft_name(buft), ggml_backend_dev_name(dev)); + return nullptr; + } + + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + if (!props.caps.async || !props.caps.host_buffer || !props.caps.events) { + LLAMA_LOG_DEBUG("%s: device %s does not support async, host buffers or events\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + auto * host_buft = ggml_backend_dev_host_buffer_type(dev); + if (!host_buft) { + LLAMA_LOG_DEBUG("%s: no host buffer type found for device %s\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + // If the backend is supported, create pinned memory buffers and events for synchronisation. + for (size_t idx = 0; idx < n_buffers; ++idx) { + auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size); + if (!buf) { + LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + host_buffers.emplace_back(buf); + host_ptrs.emplace_back(ggml_backend_buffer_get_base(buf)); + + auto * event = ggml_backend_event_new(dev); + if (!event) { + LLAMA_LOG_DEBUG("%s: failed to create event for async uploads for device %s\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + events.emplace_back(event); + } + + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + LLAMA_LOG_DEBUG("%s: failed to initialize backend for device %s for async uploads\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + return backend; + }(__func__); + + if (upload_backend) { + LLAMA_LOG_DEBUG("%s: using async uploads for device %s, buffer type %s, backend %s\n", __func__, + ggml_backend_dev_name(ggml_backend_get_device(upload_backend)), + ggml_backend_buft_name(ggml_backend_buffer_get_type(bufs.at(0))), + ggml_backend_name(upload_backend)); + } + + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { + const auto * weight = get_weight(ggml_get_name(cur)); + if (weight == nullptr) { + // this can happen with split experts models + continue; + } + + if (progress_callback) { + if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) { + return false; + } + } + + size_t n_size = ggml_nbytes(cur); + + if (use_mmap) { + const auto & mapping = mappings.at(weight->idx); + ggml_backend_buffer_t buf_mmap = nullptr; + if (bufs.count(weight->idx)) { + buf_mmap = bufs.at(weight->idx); + } + uint8_t * data = (uint8_t *) mapping->addr() + weight->offs; + + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size)); + })); + } + + GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated + if (buf_mmap && cur->data == nullptr) { + ggml_backend_tensor_alloc(buf_mmap, cur, data); + if (lmlocks) { + const 
auto & lmlock = lmlocks->at(weight->idx); + lmlock->grow_to(weight->offs + n_size); + } + + auto & mmap_used = mmaps_used[weight->idx]; + mmap_used.first = std::min(mmap_used.first, weight->offs); + mmap_used.second = std::max(mmap_used.second, weight->offs + n_size); + } else { + ggml_backend_tensor_set(cur, data, 0, n_size); + } + } else { + const auto & file = files.at(weight->idx); + if (ggml_backend_buffer_is_host(cur->buffer)) { + file->seek(weight->offs, SEEK_SET); + file->read_raw(cur->data, n_size); + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size)); + })); + } + } else { + // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. + if (upload_backend) { + file->seek(weight->offs, SEEK_SET); + + size_t bytes_read = 0; + + while (bytes_read < n_size) { + size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + + ggml_backend_event_synchronize(events[buffer_idx]); + file->read_raw(host_ptrs[buffer_idx], read_iteration); + ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + ggml_backend_event_record(events[buffer_idx], upload_backend); + + bytes_read += read_iteration; + ++buffer_idx; + buffer_idx %= n_buffers; + } + } else { + read_buf.resize(n_size); + file->seek(weight->offs, SEEK_SET); + file->read_raw(read_buf.data(), n_size); + ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } + } + } + } + + size_done += n_size; + } + + // free temporary resources used for async uploads + for (auto * event : events) { + ggml_backend_event_synchronize(event); + ggml_backend_event_free(event); + } + for (auto * buf : host_buffers) { + ggml_backend_buffer_free(buf); + } + ggml_backend_free(upload_backend); + + // check validation results + bool validation_failed = false; + for (auto & future : validation_result) { + auto result = future.get(); + if (!result.second) { + LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first)); + validation_failed = true; + } + } + if (validation_failed) { + throw std::runtime_error("found tensors with invalid data"); + } + + // check if this is the last call and do final cleanup + if (size_done >= size_data) { + // unmap offloaded tensors and metadata + if (use_mmap) { + for (uint32_t idx = 0; idx < mappings.size(); idx++) { + const auto & mmap_used = mmaps_used.at(idx); + auto & mapping = mappings.at(idx); + mapping->unmap_fragment(0, mmap_used.first); + if (mmap_used.second != 0) { + mapping->unmap_fragment(mmap_used.second, mapping->size()); + } + } + } + if (progress_callback) { + // Even though the model is done loading, we still honor + // cancellation since we need to free allocations. 
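Reviewer note: when async uploads are available, `load_all_data` streams each tensor through a small ring of pinned staging buffers: wait on the buffer's event, read a chunk from disk into it, start an asynchronous copy to the device, then advance to the next buffer. The sketch below keeps only that chunking loop; `wait_event`, `read_chunk` and `async_copy` are stubs standing in for the backend calls and are not part of the patch.

#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr size_t n_buffers   = 4;
constexpr size_t buffer_size = 1 * 1024 * 1024;  // 1 MiB per staging buffer

// Stubs standing in for event waits, file reads and async device copies.
static void wait_event(size_t buf)                         { std::printf("wait   buf %zu\n", buf); }
static void read_chunk(size_t buf, size_t len)             { std::printf("read   buf %zu (%zu bytes)\n", buf, len); }
static void async_copy(size_t buf, size_t off, size_t len) { std::printf("upload buf %zu -> tensor+%zu (%zu bytes)\n", buf, off, len); }

int main() {
    const size_t n_size = 3 * buffer_size + 1234;  // size of one tensor
    size_t bytes_read = 0;
    size_t buffer_idx = 0;

    while (bytes_read < n_size) {
        const size_t chunk = std::min(buffer_size, n_size - bytes_read);

        wait_event(buffer_idx);                    // previous upload from this buffer must have finished
        read_chunk(buffer_idx, chunk);             // file -> pinned host buffer
        async_copy(buffer_idx, bytes_read, chunk); // pinned host buffer -> device, asynchronously
        // (the real code records an event here so the next reuse of this buffer can wait on it)

        bytes_read += chunk;
        buffer_idx  = (buffer_idx + 1) % n_buffers;
    }
}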
+ return progress_callback(1.0f, progress_callback_user_data); + } + } + + return true; +} diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h new file mode 100644 index 000000000..1ec478195 --- /dev/null +++ b/src/llama-model-loader.h @@ -0,0 +1,158 @@ +#pragma once + +#include "llama.h" + +#include "llama-impl.h" +#include "llama-arch.h" +#include "llama-mmap.h" + +#include "ggml-cpp.h" + +#include +#include +#include +#include + +using llama_buf_map = std::unordered_map; + +enum llama_fver { + GGUF_FILE_VERSION_V1 = 1, + GGUF_FILE_VERSION_V2 = 2, + GGUF_FILE_VERSION_V3 = 3, +}; + +const char * llama_file_version_name(llama_fver version); + +struct llama_model_loader { + // Holds information on a model weight + struct llama_tensor_weight { + uint16_t idx; // source file index + size_t offs; // tensor data offset in the original file + + ggml_tensor * tensor; + + llama_tensor_weight(const llama_file * file, uint16_t idx, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { + const int tensor_idx = gguf_find_tensor(gguf_ctx, ggml_get_name(tensor)); + if (tensor_idx < 0) { + throw std::runtime_error(format("tensor '%s' not found in the model", ggml_get_name(tensor))); + } + + offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); + if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) { + throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", ggml_get_name(tensor))); + } + } + }; + + // custom comparator to sort weights more nicely by layer + struct weight_name_comparer { + bool operator()(const std::string & a, const std::string & b) const { + int a_layer = -1; + int b_layer = -1; + sscanf(a.c_str(), "blk.%d.", &a_layer); + sscanf(b.c_str(), "blk.%d.", &b_layer); + if (a_layer != b_layer) { + return a_layer < b_layer; + } + return a < b; + } + }; + + static const int TENSOR_NOT_REQUIRED = 1; + static const int TENSOR_DUPLICATED = 2; + + int n_kv = 0; + int n_tensors = 0; + int n_created = 0; + + uint64_t n_elements = 0; + size_t n_bytes = 0; + + bool use_mmap = false; + bool check_tensors; + + llama_files files; + llama_ftype ftype; + llama_fver fver; + + llama_mmaps mappings; + + std::map weights_map; + std::unordered_map kv_overrides; + + gguf_context_ptr meta; + std::vector contexts; + + std::string arch_name; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + size_t size_done = 0; + size_t size_data = 0; + std::vector> mmaps_used; + + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p); + + template + typename std::enable_if::value, bool>::type + get_arr_n(const std::string & key, T & result, bool required = true); + + template + typename std::enable_if::value, bool>::type + get_arr_n(enum llm_kv kid, T & result, bool required = true); + + template + bool get_arr(const std::string & key, std::vector & result, bool required = true); + + template + bool get_arr(const std::string & key, std::array & result, bool required = true); + + template + bool get_arr(enum llm_kv kid, T & result, bool required = true); + + template + bool get_key(const std::string & key, T & result, bool required = true); + + template + bool get_key(enum llm_kv kid, T & result, bool required = true); + + template + bool get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required = true); + + template + bool get_key_or_arr(enum llm_kv kid, 
T & result, uint32_t n, bool required = true); + + std::string get_arch_name() const; + + enum llm_arch get_arch() const; + + const llama_tensor_weight * get_weight(const char * name) const; + + const llama_tensor_weight & require_weight(const char * name) const; + + struct ggml_tensor * get_tensor_meta(const char * name) const; + + struct ggml_tensor * require_tensor_meta(const std::string & name) const; + + const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const; + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags = 0); + + struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); + + void done_getting_tensors() const; + + void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); + + void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const; + + // for backwards compatibility, does not support ggml-backend + void load_data_for(struct ggml_tensor * cur) const; + + // Returns false if cancelled by progress_callback + bool load_all_data( + struct ggml_context * ctx, + llama_buf_map & bufs, + llama_mlocks * lmlocks, + llama_progress_callback progress_callback, + void * progress_callback_user_data); +}; diff --git a/src/llama-model.cpp b/src/llama-model.cpp new file mode 100644 index 000000000..7260cb155 --- /dev/null +++ b/src/llama-model.cpp @@ -0,0 +1,2213 @@ +#include "llama-model.h" + +#include "llama-impl.h" +#include "llama-model-loader.h" + +#include "unicode.h" // TODO: remove + +#include +#include +#include +#include +#include + +static const size_t kiB = 1024; +static const size_t MiB = 1024*kiB; +static const size_t GiB = 1024*MiB; + +const char * llm_type_name(llm_type type) { + switch (type) { + case MODEL_14M: return "14M"; + case MODEL_17M: return "17M"; + case MODEL_22M: return "22M"; + case MODEL_33M: return "33M"; + case MODEL_60M: return "60M"; + case MODEL_70M: return "70M"; + case MODEL_80M: return "80M"; + case MODEL_109M: return "109M"; + case MODEL_137M: return "137M"; + case MODEL_160M: return "160M"; + case MODEL_220M: return "220M"; + case MODEL_250M: return "250M"; + case MODEL_270M: return "270M"; + case MODEL_335M: return "335M"; + case MODEL_410M: return "410M"; + case MODEL_450M: return "450M"; + case MODEL_770M: return "770M"; + case MODEL_780M: return "780M"; + case MODEL_0_5B: return "0.5B"; + case MODEL_1B: return "1B"; + case MODEL_1_3B: return "1.3B"; + case MODEL_1_4B: return "1.4B"; + case MODEL_1_5B: return "1.5B"; + case MODEL_1_6B: return "1.6B"; + case MODEL_2B: return "2B"; + case MODEL_2_8B: return "2.8B"; + case MODEL_3B: return "3B"; + case MODEL_4B: return "4B"; + case MODEL_6B: return "6B"; + case MODEL_6_9B: return "6.9B"; + case MODEL_7B: return "7B"; + case MODEL_8B: return "8B"; + case MODEL_9B: return "9B"; + case MODEL_11B: return "11B"; + case MODEL_12B: return "12B"; + case MODEL_13B: return "13B"; + case MODEL_14B: return "14B"; + case MODEL_15B: return "15B"; + case MODEL_16B: return "16B"; + case MODEL_20B: return "20B"; + case MODEL_30B: return "30B"; + case MODEL_32B: return "32B"; + case MODEL_34B: return "34B"; + case MODEL_35B: return "35B"; + case MODEL_40B: return "40B"; + case MODEL_65B: return "65B"; + case MODEL_70B: return "70B"; + case MODEL_236B: return "236B"; + case 
MODEL_314B: return "314B"; + case MODEL_671B: return "671B"; + case MODEL_SMALL: return "0.1B"; + case MODEL_MEDIUM: return "0.4B"; + case MODEL_LARGE: return "0.8B"; + case MODEL_XL: return "1.5B"; + case MODEL_A1_7B: return "A1.7B"; + case MODEL_A2_7B: return "A2.7B"; + case MODEL_8x7B: return "8x7B"; + case MODEL_8x22B: return "8x22B"; + case MODEL_16x12B: return "16x12B"; + case MODEL_16x3_8B: return "16x3.8B"; + case MODEL_10B_128x3_66B: return "10B+128x3.66B"; + case MODEL_57B_A14B: return "57B.A14B"; + case MODEL_27B: return "27B"; + default: return "?B"; + } +} + +static std::string llama_model_ftype_name(llama_ftype ftype) { + if (ftype & LLAMA_FTYPE_GUESSED) { + return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; + } + + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + + default: return "unknown, may not work"; + } +} + +static const char * llama_expert_gating_func_name(llama_expert_gating_func_type type) { + switch (type) { + case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax"; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid"; + default: return "unknown"; + } +} + +std::string llama_model_arch_name (const llama_model & model) { + return llm_arch_name(model.arch); +} + +std::string llama_model_type_name (const llama_model & model) { + return llm_type_name(model.type); +} + +std::string llama_model_ftype_name(const llama_model & model) { + return llama_model_ftype_name(model.ftype); +} + +template +static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx { ggml_init(params) 
}; + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; + ggml_tensor * op_tensor = fn(ctx.get()); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op_tensor->src[i] != nullptr) { + assert(op_tensor->src[i]->buffer == nullptr); + op_tensor->src[i]->buffer = buf.get(); + } + } + + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + + return op_supported; +} + +template +static ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) { + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (buft_supported(cur_buft, cur_dev, fn)) { + return cur_buft; + } + } + + throw std::runtime_error(format("no suitable buffer type found")); +} + +ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il) { + return select_buft( + *model.dev_layer.at(il).buft_list, + [&](ggml_context * ctx) { + ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd); + ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd); + return ggml_add(ctx, cur, layer_dir); + }); +} + +struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name) { + auto it = std::find_if(model.tensors_by_name.begin(), model.tensors_by_name.end(), + [name](const std::pair & it) { + return it.first == name; + }); + if (it == model.tensors_by_name.end()) { + return nullptr; + } + + return it->second; +} + +size_t llama_model_max_nodes(const llama_model & model) { + return std::max(8192, model.tensors_by_name.size()*5); +} + +static const std::map LLAMA_ROPE_SCALING_TYPES = { + { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, + { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, + { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, + { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" }, +}; + +static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { + for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { + if (kv.second == name) { + return (llama_rope_scaling_type) kv.first; + } + } + + return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; +} + +// NOTE: avoid ever using this except for building the token_to_piece caches +static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { + std::string piece; + piece.resize(piece.capacity()); // using string internal cache + const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); + } + + return piece; +} + +void llm_load_stats(llama_model_loader & ml, llama_model & model) { + model.n_elements = ml.n_elements; + model.n_bytes = ml.n_bytes; +} + +void llm_load_arch(llama_model_loader & ml, llama_model & model) { + model.arch = ml.get_arch(); + if (model.arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); + } +} + +void llm_load_hparams(llama_model_loader & ml, llama_model & model) { + auto & hparams = model.hparams; + const gguf_context * ctx = ml.meta.get(); + + // get metadata as string + for (int i = 0; i < gguf_get_n_kv(ctx); i++) { + enum gguf_type type = gguf_get_kv_type(ctx, i); + if (type == 
GGUF_TYPE_ARRAY) { + continue; + } + const char * name = gguf_get_key(ctx, i); + const std::string value = gguf_kv_to_str(ctx, i); + model.gguf_kv.emplace(name, value); + } + + // get general kv + ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); + + // get hparams kv + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false); + + // everything past this point is not vocab-related + if (hparams.vocab_only) { + return; + } + + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + + if (model.arch == LLM_ARCH_WAVTOKENIZER_DEC) { + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + + ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); + ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); + + ml.get_key(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd); + ml.get_key(LLM_KV_CONVNEXT_BLOCK_COUNT, hparams.convnext.n_layer); + } + + GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); + GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); + if (hparams.n_expert > 0) { + GGML_ASSERT(hparams.n_expert_used > 0); + } else { + GGML_ASSERT(hparams.n_expert_used == 0); + } + + // zero-out the array hparams + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); + std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + + // n_head_kv is optional, default to n_head + hparams.n_head_kv_arr = hparams.n_head_arr; + + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false); + + bool rope_finetuned = false; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + hparams.n_ctx_orig_yarn = hparams.n_ctx_train; + ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false); + + // rope_freq_base (optional) + hparams.rope_freq_base_train = 10000.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, false); + + std::string rope_scaling("linear"); + ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false); + hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); + GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED); + + // rope_freq_scale (inverse of the kv) is optional + float ropescale = 0.0f; + if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) { + // try the old key name + ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false); + } + hparams.rope_freq_scale_train = ropescale == 0.0f ? 
1.0f : 1.0f/ropescale; + + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + + // non-transformer models do not have attention heads + if (hparams.n_head() > 0) { + // gpt-neox n_rot = rotary_pct * (n_embd / n_head) + // gpt-j n_rot = rotary_dim + + hparams.n_embd_head_k = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + + hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); + + // sanity check for n_rot (optional) + hparams.n_rot = hparams.n_embd_head_k; + + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) { + if (hparams.n_rot != hparams.n_embd_head_k) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); + } + } + } else { + hparams.n_rot = 0; + hparams.n_embd_head_k = 0; + hparams.n_embd_head_v = 0; + } + + using e_model = llm_type; // TMP + + // arch-specific KVs + switch (model.arch) { + case LLM_ARCH_LLAMA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 8) { + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_8x7B; break; + case 56: model.type = e_model::MODEL_8x22B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } else { + switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_1B; break; // Llama 3.2 1B + case 22: model.type = e_model::MODEL_1B; break; + case 26: model.type = e_model::MODEL_3B; break; + case 28: model.type = e_model::MODEL_3B; break; // Llama 3.2 3B + // granite uses a vocab with len 49152 + case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break; + case 36: model.type = e_model::MODEL_8B; break; // granite + case 40: model.type = e_model::MODEL_13B; break; + case 48: model.type = e_model::MODEL_34B; break; + case 60: model.type = e_model::MODEL_30B; break; + case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? 
e_model::MODEL_65B : e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } + } break; + case LLM_ARCH_DECI: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 80: model.type = e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_MINICPM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + + switch (hparams.n_layer) { + case 52: model.type = e_model::MODEL_1B; break; + case 40: model.type = e_model::MODEL_2B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_MINICPM3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + + switch (hparams.n_layer) { + case 62: model.type = e_model::MODEL_4B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_GROK: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 64: model.type = e_model::MODEL_314B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_FALCON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 60: model.type = e_model::MODEL_40B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_BAICHUAN: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + + if (model.type == e_model::MODEL_13B) { + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } + } break; + case LLM_ARCH_STARCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 36: model.type = e_model::MODEL_3B; break; + case 42: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_15B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_REFACT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_1B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } break; + case LLM_ARCH_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 3: + model.type = e_model::MODEL_17M; break; // bge-micro + case 6: + model.type = e_model::MODEL_22M; break; // MiniLM-L6 + case 12: + switch (hparams.n_embd) { + case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small + case 768: model.type = e_model::MODEL_109M; break; // bge-base + 
default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 24: + model.type = e_model::MODEL_335M; break; // bge-large + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_JINA_BERT_V2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + hparams.f_max_alibi_bias = 8.0f; + + switch (hparams.n_layer) { + case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small + case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_NOMIC_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + if (hparams.n_layer == 12 && hparams.n_embd == 768) { + model.type = e_model::MODEL_137M; + } + } break; + case LLM_ARCH_BLOOM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } break; + case LLM_ARCH_MPT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_30B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_STABLELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_3B; break; + case 40: model.type = e_model::MODEL_12B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN2VL: + { + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + } + // fall through + case LLM_ARCH_QWEN2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break; + case 28: model.type = hparams.n_embd == 1536 ? e_model::MODEL_1_5B : e_model::MODEL_7B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 36: model.type = e_model::MODEL_3B; break; + case 40: model.type = hparams.n_head() == 20 ? 
e_model::MODEL_4B : e_model::MODEL_13B; break; + case 48: model.type = e_model::MODEL_14B; break; + case 64: model.type = e_model::MODEL_32B; break; + case 80: model.type = e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN2MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_A2_7B; break; + case 28: model.type = e_model::MODEL_57B_A14B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_PHI2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_PHI3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_3B; break; + case 40: model.type = e_model::MODEL_14B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + + // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931 + if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) { + // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct + hparams.n_swa = 2047; + } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) { + // default value for Phi-3-mini-128k-instruct + hparams.n_swa = 262144; + } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) { + // default value for Phi-3-medium-128k-instruct + hparams.n_swa = 131072; + } + bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (!found_swa && hparams.n_swa == 0) { + throw std::runtime_error("invalid value for sliding_window"); + } + } break; + case LLM_ARCH_PHIMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_16x3_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_PLAMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_13B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_GPT2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 12: model.type = e_model::MODEL_SMALL; break; + case 24: model.type = e_model::MODEL_MEDIUM; break; + case 36: model.type = e_model::MODEL_LARGE; break; + case 48: model.type = e_model::MODEL_XL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_CODESHELL: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 42: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_ORION: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_14B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_INTERNLM2: + 
{ + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_20B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_GEMMA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 18: model.type = e_model::MODEL_2B; break; + case 28: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_GEMMA2: + { + hparams.n_swa = 4096; // default value of gemma 2 + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + hparams.attn_soft_cap = true; + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_2B; break; + case 42: model.type = e_model::MODEL_9B; break; + case 46: model.type = e_model::MODEL_27B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_STARCODER2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 30: model.type = e_model::MODEL_3B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_15B; break; + case 52: model.type = e_model::MODEL_20B; break; // granite + case 88: model.type = e_model::MODEL_34B; break; // granite + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_MAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: model.type = e_model::MODEL_SMALL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: model.type = e_model::MODEL_MEDIUM; break; + case 1536: model.type = e_model::MODEL_LARGE; break; + case 2048: model.type = e_model::MODEL_XL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_XVERSE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 80: model.type = e_model::MODEL_65B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_COMMAND_R: + { + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_35B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_COHERE2: + { + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_LOGIT_SCALE, 
hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_DBRX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_16x12B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_OLMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + + switch (hparams.n_layer) { + case 22: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 80: model.type = e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_OLMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_OLMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_A1_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_OPENELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_270M; break; + case 20: model.type = e_model::MODEL_450M; break; + case 28: model.type = e_model::MODEL_1B; break; + case 36: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_GPTNEOX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + switch (hparams.n_layer) { + case 6: + switch (hparams.n_ff()) { + case 512: model.type = e_model::MODEL_14M; break; + case 2048: model.type = e_model::MODEL_70M; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 12: + switch (hparams.n_ff()) { + case 3072: model.type = e_model::MODEL_160M; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 16: + switch (hparams.n_ff()) { + case 8192: model.type = e_model::MODEL_1B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: model.type = e_model::MODEL_410M; break; + case 8192: model.type = e_model::MODEL_1_4B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 32: + switch (hparams.n_ff()) { + case 10240: model.type = e_model::MODEL_2_8B; break; + case 16384: model.type = e_model::MODEL_6_9B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 36: + switch (hparams.n_ff()) { + case 20480: model.type = e_model::MODEL_12B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 44: + switch (hparams.n_ff()) { + case 24576: model.type = e_model::MODEL_20B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_ARCTIC: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 128) { + switch 
(hparams.n_layer) { + case 35: model.type = e_model::MODEL_10B_128x3_66B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } else { + model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_DEEPSEEK: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + + switch (hparams.n_layer) { + case 28: model.type = e_model::MODEL_20B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_DEEPSEEK2: + { + bool is_lite = (hparams.n_layer == 27); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + if (!is_lite) { + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + // for compatibility with existing DeepSeek V2 and V2.5 GGUFs + // that have no expert_gating_func model parameter set + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + switch (hparams.n_layer) { + case 27: model.type = e_model::MODEL_16B; break; + case 60: model.type = e_model::MODEL_236B; break; + case 61: model.type = e_model::MODEL_671B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_CHATGLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: model.type = e_model::MODEL_6B; break; + case 40: model.type = e_model::MODEL_9B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_BITNET: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_T5: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + + uint32_t dec_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { + hparams.dec_start_token_id = dec_start_token_id; + } + + switch (hparams.n_layer) { + case 6: model.type = e_model::MODEL_60M; break; // t5-small + case 8: model.type = e_model::MODEL_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff()) { + case 3072: model.type = e_model::MODEL_220M; break; // t5-base + case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: model.type = e_model::MODEL_770M; break; // t5-large + case 2816: model.type = e_model::MODEL_780M; break; // flan-t5-large + case 16384: model.type = e_model::MODEL_3B; 
break; // t5-3b + case 5120: model.type = e_model::MODEL_3B; break; // flan-t5-xl + case 65536: model.type = e_model::MODEL_11B; break; // t5-11b + case 10240: model.type = e_model::MODEL_11B; break; // flan-t5-xxl + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_T5ENCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + model.type = e_model::MODEL_UNKNOWN; + } break; + case LLM_ARCH_JAIS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1_3B; break; + case 40: model.type = e_model::MODEL_13B; break; + /* TODO: add variants */ + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_NEMOTRON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_4B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_EXAONE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_RWKV6: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 61: model.type = e_model::MODEL_14B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_3B; break; + case 40: model.type = e_model::MODEL_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_CHAMELEON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_34B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); + 
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + } break; + default: throw std::runtime_error("unsupported model architecture"); + } + + model.ftype = ml.ftype; + + if (hparams.f_max_alibi_bias > 0.0f) { + hparams.use_alibi = true; + } + + hparams.rope_type = llama_rope_type(&model); +} + +void llm_load_vocab(llama_model_loader & ml, llama_model & model) { + auto & vocab = model.vocab; + + struct gguf_context * ctx = ml.meta.get(); + + const auto kv = LLM_KV(model.arch); + + // determine vocab type + { + std::string tokenizer_model; + std::string tokenizer_pre; + + ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); + ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + + if (tokenizer_model == "no_vocab" || tokenizer_model == "none") { + vocab.type = LLAMA_VOCAB_TYPE_NONE; + + // default special tokens + vocab.special_bos_id = LLAMA_TOKEN_NULL; + vocab.special_eos_id = LLAMA_TOKEN_NULL; + vocab.special_unk_id = LLAMA_TOKEN_NULL; + vocab.special_sep_id = LLAMA_TOKEN_NULL; + vocab.special_pad_id = LLAMA_TOKEN_NULL; + vocab.special_cls_id = LLAMA_TOKEN_NULL; + vocab.special_mask_id = LLAMA_TOKEN_NULL; + vocab.linefeed_id = LLAMA_TOKEN_NULL; + + // read vocab size from metadata + if (!ml.get_key(LLM_KV_VOCAB_SIZE, vocab.n_vocab, false)) { + vocab.n_vocab = 0; + LLAMA_LOG_WARN("%s: there is no vocab_size in metadata, vocab.n_vocab will be set to %u\n", __func__, vocab.n_vocab); + } + return; + } + + if (tokenizer_model == "llama") { + vocab.type = LLAMA_VOCAB_TYPE_SPM; + + // default special tokens + vocab.special_bos_id = 1; + vocab.special_eos_id = 2; + vocab.special_unk_id = 0; + vocab.special_sep_id = LLAMA_TOKEN_NULL; + vocab.special_pad_id = LLAMA_TOKEN_NULL; + vocab.special_cls_id = LLAMA_TOKEN_NULL; + vocab.special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "bert") { + vocab.type = LLAMA_VOCAB_TYPE_WPM; + + // default special tokens + vocab.special_bos_id = LLAMA_TOKEN_NULL; + vocab.special_eos_id = LLAMA_TOKEN_NULL; + vocab.special_unk_id = 100; + vocab.special_sep_id = 102; + vocab.special_pad_id = 0; + vocab.special_cls_id = 101; + vocab.special_mask_id = 103; + } else if (tokenizer_model == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + vocab.bpe_ranks.emplace(std::make_pair(first, second), i); + } + + // default special tokens + vocab.special_bos_id = 11; + vocab.special_eos_id = 11; + vocab.special_unk_id = LLAMA_TOKEN_NULL; + vocab.special_sep_id = LLAMA_TOKEN_NULL; + vocab.special_pad_id = LLAMA_TOKEN_NULL; + vocab.special_cls_id = LLAMA_TOKEN_NULL; + vocab.special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "t5") { + vocab.type = LLAMA_VOCAB_TYPE_UGM; + + // default special tokens + vocab.special_bos_id = LLAMA_TOKEN_NULL; + vocab.special_eos_id = 1; + vocab.special_unk_id = 2; + vocab.special_sep_id = LLAMA_TOKEN_NULL; + vocab.special_pad_id = 0; + vocab.special_cls_id = 
LLAMA_TOKEN_NULL; + vocab.special_mask_id = LLAMA_TOKEN_NULL; + + const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); + if (precompiled_charsmap_keyidx != -1) { + size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); + vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap); +#ifdef IS_BIG_ENDIAN + // correct endiannes of data in precompiled_charsmap binary blob + uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0]; + *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); + assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); + size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); + uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)]; + for (size_t i = 0; i < xcda_array_size; ++i) { + xcda_array[i] = __builtin_bswap32(xcda_array[i]); + } +#endif + } + } else if (tokenizer_model == "rwkv") { + vocab.type = LLAMA_VOCAB_TYPE_RWKV; + + // default special tokens + vocab.special_bos_id = LLAMA_TOKEN_NULL; + vocab.special_eos_id = LLAMA_TOKEN_NULL; + vocab.special_unk_id = LLAMA_TOKEN_NULL; + vocab.special_sep_id = LLAMA_TOKEN_NULL; + vocab.special_pad_id = LLAMA_TOKEN_NULL; + } else { + throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); + } + + // for now, only BPE models have pre-tokenizers + if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { + vocab.tokenizer_add_space_prefix = false; + vocab.tokenizer_clean_spaces = true; + if (tokenizer_pre.empty()) { + LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! 
\n", __func__); + LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if (tokenizer_pre == "default") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "llama3" || + tokenizer_pre == "llama-v3" || + tokenizer_pre == "llama-bpe"|| + tokenizer_pre == "falcon3") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + vocab.tokenizer_ignore_merges = true; + vocab.tokenizer_add_bos = true; + } else if ( + tokenizer_pre == "deepseek-llm") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-coder") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-v3") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "falcon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; + } else if ( + tokenizer_pre == "mpt") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT; + } else if ( + tokenizer_pre == "starcoder") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; + } else if ( + tokenizer_pre == "gpt-2" || + tokenizer_pre == "phi-2" || + tokenizer_pre == "jina-es" || + tokenizer_pre == "jina-de" || + tokenizer_pre == "gigachat" || + tokenizer_pre == "jina-v1-en" || + tokenizer_pre == "jina-v2-es" || + tokenizer_pre == "jina-v2-de" || + tokenizer_pre == "jina-v2-code" || + tokenizer_pre == "roberta-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "refact") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT; + } else if ( + tokenizer_pre == "command-r") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "qwen2") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "stablelm2") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2; + } else if ( + tokenizer_pre == "olmo") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO; + } else if ( + tokenizer_pre == "dbrx") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX; + } else if ( + tokenizer_pre == "smaug-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; + } else if ( + tokenizer_pre == "poro-chat") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "chatglm-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; + vocab.special_bos_id = LLAMA_TOKEN_NULL; + } else if ( + tokenizer_pre == "viking") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "jais") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS; + } else if ( + tokenizer_pre == "tekken") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_ignore_merges = true; + vocab.tokenizer_add_bos = true; + } else if ( + tokenizer_pre == "smollm") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "codeshell") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "bloom") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BLOOM; + } else if ( + tokenizer_pre == "gpt3-finnish") { + vocab.type_pre = 
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH; + } else if ( + tokenizer_pre == "exaone") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; + } else if ( + tokenizer_pre == "chameleon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "minerva-7b") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MINERVA; + } else if ( + tokenizer_pre == "megrez") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; + } else { + throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); + } + } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = true; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_add_eos = false; + } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = false; + vocab.tokenizer_clean_spaces = true; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_add_eos = false; + } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_bos = false; + vocab.tokenizer_add_eos = true; + } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = false; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_add_bos = false; + vocab.tokenizer_add_eos = false; + } else { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } + + ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.tokenizer_add_space_prefix, false); + ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false); + } + + const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + + const float * scores = nullptr; + const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); + if (score_idx != -1) { + scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + } + + const int * toktypes = nullptr; + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); + if (toktype_idx != -1) { + toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + } + + const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); + + vocab.n_vocab = n_vocab; + vocab.id_to_token.resize(n_vocab); + + for (uint32_t i = 0; i < n_vocab; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); + if (word.empty()) { + LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); + word = "[EMPTY_" + std::to_string(i) + "]"; + } + + vocab.token_to_id[word] = i; + vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); + + auto & token_data = vocab.id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores ? 
scores[i] : 0.0f; + token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; + + if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file + switch(toktypes[i]) { + case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break; + case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break; + case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break; + case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break; + case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break; + case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break; + case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + } + } + } + GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); + + vocab.init_tokenizer(); + + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' + if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + try { + vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n'); + } catch (const std::exception & e) { + LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what()); + vocab.linefeed_id = vocab.special_pad_id; + } + } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { + vocab.linefeed_id = vocab.special_pad_id; + } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { + const std::vector ids = llama_tokenize_internal(vocab, "\n", false); + GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + vocab.linefeed_id = ids[0]; + } else { + const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A + + //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + if (ids.empty()) { + LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); + vocab.linefeed_id = vocab.special_pad_id; + } else { + vocab.linefeed_id = ids[0]; + } + } + + // special tokens + { + const std::vector> special_token_types = { + { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, + { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, + { LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id }, + { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, + { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id }, + { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id }, + { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id }, + { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id }, + + // deprecated + { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_fim_pre_id }, + { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_fim_suf_id }, + { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_fim_mid_id }, + }; + + for (const auto & it : special_token_types) { + const std::string & key = kv(std::get<0>(it)); + int32_t & id = std::get<1>(it); + + uint32_t new_id; + if (!ml.get_key(std::get<0>(it), new_id, false)) { + continue; + } + if (new_id >= vocab.id_to_token.size()) { + LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n", + __func__, key.c_str(), 
new_id, id); + } else { + id = new_id; + } + } + + // Handle add_bos_token and add_eos_token + { + bool temp = true; + + if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) { + vocab.tokenizer_add_bos = temp; + } + if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) { + vocab.tokenizer_add_eos = temp; + } + } + + // auto-detect special tokens by text + // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... + // for now, we apply this workaround to find the tokens based on their text + + for (const auto & t : vocab.token_to_id) { + // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. + if (vocab.special_eot_id == LLAMA_TOKEN_NULL) { + if (false + || t.first == "<|eot_id|>" + || t.first == "<|im_end|>" + || t.first == "<|end|>" + || t.first == "" + || t.first == "<|endoftext|>" + || t.first == "" + || t.first == "<|end▁of▁sentence|>" // DeepSeek + ) { + vocab.special_eot_id = t.second; + if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", + __func__, t.second, t.first.c_str()); + vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + } + } + } + + // find EOM token: "<|eom_id|>" + if (vocab.special_eom_id == LLAMA_TOKEN_NULL) { + if (false + || t.first == "<|eom_id|>" + ) { + vocab.special_eom_id = t.second; + if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", + __func__, t.second, t.first.c_str()); + vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + } + } + } + + // find FIM_PRE token: "<|fim_prefix|>", "", "
", etc.
+            if (vocab.special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == "<fim-prefix>"
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "<PRE>"
+                        ) {
+                    vocab.special_fim_pre_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
+            if (vocab.special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == "<fim-suffix>"
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == "<SUF>"
+                        ) {
+                    vocab.special_fim_suf_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
+            if (vocab.special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == "<fim-middle>"
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == "<MID>"
+                        ) {
+                    vocab.special_fim_mid_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
+            if (vocab.special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == "<fim-pad>"
+                        || t.first == "<PAD>"
+                        ) {
+                    vocab.special_fim_pad_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
+            if (vocab.special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == "<fim-repo>"
+                        || t.first == "<REPO>"
+                        ) {
+                    vocab.special_fim_rep_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (vocab.special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    vocab.special_fim_sep_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        vocab.special_eog_ids.clear();
+
+        if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
+        }
+
+        if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
+        }
+
+        if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
+        }
+
+        for (const auto & t : vocab.token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == "<end_of_turn>"
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == "<EOT>"
+               ) {
+                vocab.special_eog_ids.insert(t.second);
+                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
+                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
+            }
+        }
+
+        // sanity checks
+        if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eot_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eom_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+    }
+
+    // build special tokens cache
+    {
+        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
+            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                vocab.cache_special_tokens.push_back(id);
+            }
+        }
+
+        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
+            [&] (const llama_vocab::id a, const llama_vocab::id b) {
+                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
+            }
+        );
+
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
+    }
+
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector<std::string> cache_token_to_piece(n_vocab);
+
+        for (uint32_t id = 0; id < n_vocab; ++id) {
+            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
+
+            size_cache += cache_token_to_piece[id].size();
+        }
+
+        std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
+    }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
+            for (auto substr : substrs) {
+                if (str.find(substr) < std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
+            uint32_t current = vocab.id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            vocab.id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer name
+        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
+            _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : vocab.cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (auto token : {"</s>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
+    }
+}
+
+void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
+    const auto & hparams = model.hparams;
+    const auto & vocab   = model.vocab;
+
+    const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
+
+    auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
+        bool is_var = false;
+
+        std::vector<uint32_t> v;
+        for (uint32_t i = 0; i < n; ++i) {
+            v.push_back(f(i));
+            if (v[i] != v[0]) {
+                is_var = true;
+            }
+        }
+
+        std::stringstream ss;
+
+        if (is_var) {
+            ss << "[";
+            for (uint32_t i = 0; i < n; ++i) {
+                ss << v[i];
+                if (i < n - 1) {
+                    ss << ", ";
+                }
+            }
+            ss << "]";
+        } else {
+            ss << v[0];
+        }
+
+        return ss.str();
+    };
+
+    // hparams
+    LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
+    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, llm_arch_name(model.arch));
+    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
+    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
+    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
+
+    if (!hparams.vocab_only) {
+        LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
+        LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
+        LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
+        LLAMA_LOG_INFO("%s: n_head           = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head(il);    }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
+        LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
+        LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
+        LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
+        LLAMA_LOG_INFO("%s: n_gqa            = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il);        }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
+        LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+        LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
+        LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
+        LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
+        LLAMA_LOG_INFO("%s: n_ff             = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
+        LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
+        LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
+        LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
+        LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
+        LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
+        LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
+        LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
+        LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
+        LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
+        LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
+        LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
+        LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
+        LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+        LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
+    }
+
+    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model).c_str());
+    LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model).c_str());
+    if (ml.n_elements >= 1e12) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, ml.n_elements*1e-12);
+    } else if (ml.n_elements >= 1e9) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
+    } else if (ml.n_elements >= 1e6) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f M\n", __func__, ml.n_elements*1e-6);
+    } else {
+        LLAMA_LOG_INFO("%s: model params     = %.2f K\n", __func__, ml.n_elements*1e-3);
+    }
+    if (ml.n_bytes < GiB) {
+        LLAMA_LOG_INFO("%s: model size       = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0,        ml.n_bytes*8.0/ml.n_elements);
+    } else {
+        LLAMA_LOG_INFO("%s: model size       = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+    }
+
+    // general kv
+    LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
+
+    // special tokens
+    if (vocab.special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,     vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
+    if (vocab.special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,     vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
+    if (vocab.special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,     vocab.id_to_token[vocab.special_eot_id].text.c_str() );  }
+    if (vocab.special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,     vocab.id_to_token[vocab.special_eom_id].text.c_str() );  }
+    if (vocab.special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,     vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
+    if (vocab.special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,     vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
+    if (vocab.special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,     vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
+    if (vocab.special_cls_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,     vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
+    if (vocab.special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id,    vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
+
+    if (vocab.linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,        vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
+
+    if (vocab.special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token[vocab.special_fim_pre_id].text.c_str() ); }
+    if (vocab.special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token[vocab.special_fim_suf_id].text.c_str() ); }
+    if (vocab.special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token[vocab.special_fim_mid_id].text.c_str() ); }
+    if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token[vocab.special_fim_pad_id].text.c_str() ); }
+    if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token[vocab.special_fim_rep_id].text.c_str() ); }
+    if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token[vocab.special_fim_sep_id].text.c_str() ); }
+
+    for (const auto & id : vocab.special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
+    }
+
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+
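+    // arch-specific hyperparameters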
+    if (model.arch == LLM_ARCH_DEEPSEEK) {
+        LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
+        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
+        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
+    }
+
+    if (model.arch == LLM_ARCH_DEEPSEEK2) {
+        LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
+        LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
+        LLAMA_LOG_INFO("%s: n_lora_kv            = %d\n",     __func__, hparams.n_lora_kv);
+        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
+        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
+        LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n",     __func__, hparams.expert_weights_norm);
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",     __func__, llama_expert_gating_func_name((enum llama_expert_gating_func_type) hparams.expert_gating_func));
+        LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
+    }
+
+    if (model.arch == LLM_ARCH_QWEN2MOE) {
+        LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
+    }
+
+    if (model.arch == LLM_ARCH_MINICPM || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
+        LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
+        LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
+        LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
+    }
+}
+
+//
+// interface implementation
+//
+
+struct llama_model_params llama_model_default_params() {
+    struct llama_model_params result = {
+        /*.devices                     =*/ nullptr,
+        /*.n_gpu_layers                =*/ 0,
+        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
+        /*.main_gpu                    =*/ 0,
+        /*.tensor_split                =*/ nullptr,
+        /*.rpc_servers                 =*/ nullptr,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
+        /*.kv_overrides                =*/ nullptr,
+        /*.vocab_only                  =*/ false,
+        /*.use_mmap                    =*/ true,
+        /*.use_mlock                   =*/ false,
+        /*.check_tensors               =*/ false,
+    };
+
+#ifdef GGML_USE_METAL
+    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
+    result.n_gpu_layers = 999;
+#endif
+
+    return result;
+}
+
+void llama_free_model(struct llama_model * model) {
+    llama_model_free(model);
+}
+
+void llama_model_free(struct llama_model * model) {
+    delete model;
+}
+
+enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
+    return model->vocab.type;
+}
+
+int32_t llama_n_vocab(const struct llama_model * model) {
+    return model->hparams.n_vocab;
+}
+
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_n_embd(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_n_layer(const struct llama_model * model) {
+    return model->hparams.n_layer;
+}
+
+int32_t llama_n_head(const struct llama_model * model) {
+    return model->hparams.n_head();
+}
+
+enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+    switch (model->arch) {
+        // these models do not use RoPE
+        case LLM_ARCH_GPT2:
+        case LLM_ARCH_GPTJ:
+        case LLM_ARCH_MPT:
+        case LLM_ARCH_REFACT:
+        case LLM_ARCH_BLOOM:
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_T5:
+        case LLM_ARCH_T5ENCODER:
+        case LLM_ARCH_JAIS:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_WAVTOKENIZER_DEC:
+            return LLAMA_ROPE_TYPE_NONE;
+
+        // use what we call a normal RoPE, operating on pairs of consecutive head values
+        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_DECI:
+        case LLM_ARCH_BAICHUAN:
+        case LLM_ARCH_STARCODER:
+        case LLM_ARCH_PLAMO:
+        case LLM_ARCH_ORION:
+        case LLM_ARCH_INTERNLM2:
+        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_XVERSE:
+        case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_COHERE2:
+        case LLM_ARCH_OLMO:
+        case LLM_ARCH_ARCTIC:
+        case LLM_ARCH_DEEPSEEK:
+        case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_CHATGLM:
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_CHAMELEON:
+            return LLAMA_ROPE_TYPE_NORM;
+
+        // the pairs of head values are offset by n_rot/2
+        case LLM_ARCH_FALCON:
+        case LLM_ARCH_GROK:
+        case LLM_ARCH_DBRX:
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_STABLELM:
+        case LLM_ARCH_BITNET:
+        case LLM_ARCH_QWEN:
+        case LLM_ARCH_QWEN2:
+        case LLM_ARCH_QWEN2MOE:
+        case LLM_ARCH_OLMO2:
+        case LLM_ARCH_OLMOE:
+        case LLM_ARCH_PHI2:
+        case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
+        case LLM_ARCH_GEMMA:
+        case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_OPENELM:
+        case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_CODESHELL:
+        case LLM_ARCH_NEMOTRON:
+        case LLM_ARCH_EXAONE:
+        case LLM_ARCH_MINICPM3:
+            return LLAMA_ROPE_TYPE_NEOX;
+
+        case LLM_ARCH_QWEN2VL:
+            return LLAMA_ROPE_TYPE_MROPE;
+
+        // all model arches should be listed explicitly here
+        case LLM_ARCH_UNKNOWN:
+            GGML_ABORT("unknown architecture");
+    }
+
+    return LLAMA_ROPE_TYPE_NONE;
+}
+
+float llama_rope_freq_scale_train(const struct llama_model * model) {
+    return model->hparams.rope_freq_scale_train;
+}
+
+int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
+    const auto & it = model->gguf_kv.find(key);
+    if (it == model->gguf_kv.end()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_model_meta_count(const struct llama_model * model) {
+    return (int)model->gguf_kv.size();
+}
+
+int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->first.c_str());
+}
+
+int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
+    return snprintf(buf, buf_size, "%s %s %s",
+            llama_model_arch_name (*model).c_str(),
+            llama_model_type_name (*model).c_str(),
+            llama_model_ftype_name(*model).c_str());
+}
+
+uint64_t llama_model_size(const struct llama_model * model) {
+    return model->n_bytes;
+}
+
+uint64_t llama_model_n_params(const struct llama_model * model) {
+    return model->n_elements;
+}
+
+bool llama_model_has_encoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5:        return true;
+        case LLM_ARCH_T5ENCODER: return true;
+        default:                 return false;
+    }
+}
+
+bool llama_model_has_decoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5ENCODER: return false;
+        default:                 return true;
+    }
+}
+
+llama_token llama_model_decoder_start_token(const struct llama_model * model) {
+    return model->hparams.dec_start_token_id;
+}
+
+bool llama_model_is_recurrent(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA:  return true;
+        case LLM_ARCH_RWKV6:  return true;
+        default:              return false;
+    }
+}
diff --git a/src/llama-model.h b/src/llama-model.h
new file mode 100644
index 000000000..424cb0f52
--- /dev/null
+++ b/src/llama-model.h
@@ -0,0 +1,392 @@
+#pragma once
+
+#include "llama.h"
+#include "llama-arch.h"
+#include "llama-hparams.h"
+#include "llama-vocab.h"
+#include "llama-mmap.h"
+
+#include "ggml-cpp.h"
+
+#include <vector>
+
+// available models
+// TODO: this enum does not follow the enum naming convention
+enum llm_type {
+    MODEL_UNKNOWN,
+    MODEL_14M,
+    MODEL_17M,
+    MODEL_22M,
+    MODEL_33M,
+    MODEL_60M,
+    MODEL_70M,
+    MODEL_80M,
+    MODEL_109M,
+    MODEL_137M,
+    MODEL_160M,
+    MODEL_220M,
+    MODEL_250M,
+    MODEL_270M,
+    MODEL_335M,
+    MODEL_410M,
+    MODEL_450M,
+    MODEL_770M,
+    MODEL_780M,
+    MODEL_0_5B,
+    MODEL_1B,
+    MODEL_1_3B,
+    MODEL_1_4B,
+    MODEL_1_5B,
+    MODEL_1_6B,
+    MODEL_2B,
+    MODEL_2_8B,
+    MODEL_3B,
+    MODEL_4B,
+    MODEL_6B,
+    MODEL_6_9B,
+    MODEL_7B,
+    MODEL_8B,
+    MODEL_9B,
+    MODEL_11B,
+    MODEL_12B,
+    MODEL_13B,
+    MODEL_14B,
+    MODEL_15B,
+    MODEL_16B,
+    MODEL_20B,
+    MODEL_30B,
+    MODEL_32B,
+    MODEL_34B,
+    MODEL_35B,
+    MODEL_40B,
+    MODEL_65B,
+    MODEL_70B,
+    MODEL_236B,
+    MODEL_314B,
+    MODEL_671B,
+    MODEL_SMALL,
+    MODEL_MEDIUM,
+    MODEL_LARGE,
+    MODEL_XL,
+    MODEL_A1_7B,
+    MODEL_A2_7B,
+    MODEL_8x7B,
+    MODEL_8x22B,
+    MODEL_16x12B,
+    MODEL_16x3_8B,
+    MODEL_10B_128x3_66B,
+    MODEL_57B_A14B,
+    MODEL_27B,
+};
+
+struct llama_layer_posnet {
+    // resnet
+    struct ggml_tensor * norm1   = nullptr;
+    struct ggml_tensor * norm1_b = nullptr;
+
+    struct ggml_tensor * conv1   = nullptr;
+    struct ggml_tensor * conv1_b = nullptr;
+
+    struct ggml_tensor * norm2   = nullptr;
+    struct ggml_tensor * norm2_b = nullptr;
+
+    struct ggml_tensor * conv2   = nullptr;
+    struct ggml_tensor * conv2_b = nullptr;
+
+    // attention
+    struct ggml_tensor * attn_norm   = nullptr;
+    struct ggml_tensor * attn_norm_b = nullptr;
+
+    struct ggml_tensor * attn_q   = nullptr;
+    struct ggml_tensor * attn_q_b = nullptr;
+
+    struct ggml_tensor * attn_k   = nullptr;
+    struct ggml_tensor * attn_k_b = nullptr;
+
+    struct ggml_tensor * attn_v   = nullptr;
+    struct ggml_tensor * attn_v_b = nullptr;
+
+    struct ggml_tensor * attn_o   = nullptr;
+    struct ggml_tensor * attn_o_b = nullptr;
+
+    // normalize
+    struct ggml_tensor * norm   = nullptr;
+    struct ggml_tensor * norm_b = nullptr;
+};
+
+struct llama_layer_convnext {
+    struct ggml_tensor * dw   = nullptr;
+    struct ggml_tensor * dw_b = nullptr;
+
+    struct ggml_tensor * norm   = nullptr;
+    struct ggml_tensor * norm_b = nullptr;
+
+    struct ggml_tensor * pw1   = nullptr;
+    struct ggml_tensor * pw1_b = nullptr;
+
+    struct ggml_tensor * pw2   = nullptr;
+    struct ggml_tensor * pw2_b = nullptr;
+
+    struct ggml_tensor * gamma = nullptr;
+};
+
+struct llama_layer {
+    // normalization
+    struct ggml_tensor * attn_norm       = nullptr;
+    struct ggml_tensor * attn_norm_b     = nullptr;
+    struct ggml_tensor * attn_norm_2     = nullptr;
+    struct ggml_tensor * attn_norm_2_b   = nullptr;
+    struct ggml_tensor * attn_q_norm     = nullptr;
+    struct ggml_tensor * attn_q_norm_b   = nullptr;
+    struct ggml_tensor * attn_k_norm     = nullptr;
+    struct ggml_tensor * attn_k_norm_b   = nullptr;
+    struct ggml_tensor * attn_out_norm   = nullptr;
+    struct ggml_tensor * attn_out_norm_b = nullptr;
+    struct ggml_tensor * attn_q_a_norm   = nullptr;
+    struct ggml_tensor * attn_kv_a_norm  = nullptr;
+    struct ggml_tensor * attn_sub_norm   = nullptr;
+    struct ggml_tensor * attn_post_norm  = nullptr;
+    struct ggml_tensor * ffn_sub_norm    = nullptr;
+    struct ggml_tensor * attn_norm_cross = nullptr;
+    struct ggml_tensor * attn_norm_enc   = nullptr;
+
+    // attention
+    struct ggml_tensor * wq        = nullptr;
+    struct ggml_tensor * wk        = nullptr;
+    struct ggml_tensor * wv        = nullptr;
+    struct ggml_tensor * wo        = nullptr;
+    struct ggml_tensor * wqkv      = nullptr;
+    struct ggml_tensor * wq_a      = nullptr;
+    struct ggml_tensor * wq_b      = nullptr;
+    struct ggml_tensor * wkv_a_mqa = nullptr;
+    struct ggml_tensor * wkv_b     = nullptr;
+    struct ggml_tensor * wq_cross  = nullptr;
+    struct ggml_tensor * wk_cross  = nullptr;
+    struct ggml_tensor * wv_cross  = nullptr;
+    struct ggml_tensor * wo_cross  = nullptr;
+    struct ggml_tensor * wq_enc    = nullptr;
+    struct ggml_tensor * wk_enc    = nullptr;
+    struct ggml_tensor * wv_enc    = nullptr;
+    struct ggml_tensor * wo_enc    = nullptr;
+
+    // attention bias
+    struct ggml_tensor * bq   = nullptr;
+    struct ggml_tensor * bk   = nullptr;
+    struct ggml_tensor * bv   = nullptr;
+    struct ggml_tensor * bo   = nullptr;
+    struct ggml_tensor * bqkv = nullptr;
+
+    // relative position bias
+    struct ggml_tensor * attn_rel_b       = nullptr;
+    struct ggml_tensor * attn_rel_b_enc   = nullptr;
+    struct ggml_tensor * attn_rel_b_cross = nullptr;
+
+    // normalization
+    struct ggml_tensor * ffn_norm         = nullptr;
+    struct ggml_tensor * ffn_norm_b       = nullptr;
+    struct ggml_tensor * ffn_post_norm    = nullptr;
+    struct ggml_tensor * layer_out_norm   = nullptr;
+    struct ggml_tensor * layer_out_norm_b = nullptr;
+    struct ggml_tensor * ffn_norm_exps    = nullptr;
+    struct ggml_tensor * ffn_norm_enc     = nullptr;
+
+    // ff
+    struct ggml_tensor * ffn_gate     = nullptr; // w1
+    struct ggml_tensor * ffn_down     = nullptr; // w2
+    struct ggml_tensor * ffn_up       = nullptr; // w3
+    struct ggml_tensor * ffn_gate_enc = nullptr;
+    struct ggml_tensor * ffn_down_enc = nullptr;
+    struct ggml_tensor * ffn_up_enc   = nullptr;
+
+    // ff MoE
+    struct ggml_tensor * ffn_gate_inp  = nullptr;
+    struct ggml_tensor * ffn_gate_exps = nullptr;
+    struct ggml_tensor * ffn_down_exps = nullptr;
+    struct ggml_tensor * ffn_up_exps   = nullptr;
+
+    // ff shared expert (shexp)
+    struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
+    struct ggml_tensor * ffn_gate_shexp     = nullptr;
+    struct ggml_tensor * ffn_down_shexp     = nullptr;
+    struct ggml_tensor * ffn_up_shexp       = nullptr;
+
+    // ff bias
+    struct ggml_tensor * ffn_gate_b = nullptr;
+    struct ggml_tensor * ffn_down_b = nullptr; // b2
+    struct ggml_tensor * ffn_up_b   = nullptr; // b3
+    struct ggml_tensor * ffn_act    = nullptr;
+    struct ggml_tensor * ffn_exp_probs_b = nullptr;
+
+    // mamba proj
+    struct ggml_tensor * ssm_in  = nullptr;
+    struct ggml_tensor * ssm_x   = nullptr;
+    struct ggml_tensor * ssm_dt  = nullptr;
+    struct ggml_tensor * ssm_out = nullptr;
+
+    // mamba
+    struct ggml_tensor * ssm_conv1d = nullptr;
+    struct ggml_tensor * ssm_a      = nullptr;
+    struct ggml_tensor * ssm_d      = nullptr;
+
+    // mamba bias
+    struct ggml_tensor * ssm_conv1d_b = nullptr;
+    struct ggml_tensor * ssm_dt_b     = nullptr;
+
+    // rwkv
+    struct ggml_tensor * time_mix_w1         = nullptr;
+    struct ggml_tensor * time_mix_w2         = nullptr;
+    struct ggml_tensor * time_mix_lerp_x     = nullptr;
+    struct ggml_tensor * time_mix_lerp_w     = nullptr;
+    struct ggml_tensor * time_mix_lerp_k     = nullptr;
+    struct ggml_tensor * time_mix_lerp_v     = nullptr;
+    struct ggml_tensor * time_mix_lerp_r     = nullptr;
+    struct ggml_tensor * time_mix_lerp_g     = nullptr;
+
+    struct ggml_tensor * time_mix_first      = nullptr;
+    struct ggml_tensor * time_mix_decay      = nullptr;
+    struct ggml_tensor * time_mix_decay_w1   = nullptr;
+    struct ggml_tensor * time_mix_decay_w2   = nullptr;
+    struct ggml_tensor * time_mix_key        = nullptr;
+    struct ggml_tensor * time_mix_value      = nullptr;
+    struct ggml_tensor * time_mix_receptance = nullptr;
+    struct ggml_tensor * time_mix_gate       = nullptr;
+
+    struct ggml_tensor * time_mix_ln     = nullptr;
+    struct ggml_tensor * time_mix_ln_b   = nullptr;
+    struct ggml_tensor * time_mix_output = nullptr;
+
+    struct ggml_tensor * channel_mix_lerp_k = nullptr;
+    struct ggml_tensor * channel_mix_lerp_r = nullptr;
+
+    struct ggml_tensor * channel_mix_key        = nullptr;
+    struct ggml_tensor * channel_mix_receptance = nullptr;
+    struct ggml_tensor * channel_mix_value      = nullptr;
+
+    // long rope factors
+    struct ggml_tensor * rope_long  = nullptr;
+    struct ggml_tensor * rope_short = nullptr;
+    struct ggml_tensor * rope_freqs = nullptr;
+
+    // bitnet scale
+    struct ggml_tensor * wq_scale       = nullptr;
+    struct ggml_tensor * wk_scale       = nullptr;
+    struct ggml_tensor * wv_scale       = nullptr;
+    struct ggml_tensor * wo_scale       = nullptr;
+    struct ggml_tensor * ffn_gate_scale = nullptr;
+    struct ggml_tensor * ffn_up_scale   = nullptr;
+    struct ggml_tensor * ffn_down_scale = nullptr;
+
+    struct llama_layer_posnet posnet;
+
+    struct llama_layer_convnext convnext;
+};
+
+struct llama_model {
+    llm_type type = MODEL_UNKNOWN;
+    llm_arch arch = LLM_ARCH_UNKNOWN;
+
+    llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
+
+    std::string name = "n/a";
+
+    llama_hparams hparams = {};
+    llama_vocab   vocab;
+
+    struct ggml_tensor * tok_embd   = nullptr;
+    struct ggml_tensor * type_embd  = nullptr;
+    struct ggml_tensor * pos_embd   = nullptr;
+    struct ggml_tensor * tok_norm   = nullptr;
+    struct ggml_tensor * tok_norm_b = nullptr;
+
+    struct ggml_tensor * output_norm     = nullptr;
+    struct ggml_tensor * output_norm_b   = nullptr;
+    struct ggml_tensor * output          = nullptr;
+    struct ggml_tensor * output_b        = nullptr;
+    struct ggml_tensor * output_norm_enc = nullptr;
+
+    // classifier
+    struct ggml_tensor * cls       = nullptr;
+    struct ggml_tensor * cls_b     = nullptr;
+    struct ggml_tensor * cls_out   = nullptr;
+    struct ggml_tensor * cls_out_b = nullptr;
+
+    struct ggml_tensor * conv1d   = nullptr;
+    struct ggml_tensor * conv1d_b = nullptr;
+
+    std::vector<llama_layer> layers;
+
+    // gguf metadata
+    std::unordered_map<std::string, std::string> gguf_kv;
+
+    llama_split_mode split_mode;
+    int main_gpu;
+    int n_gpu_layers;
+
+    std::vector<std::string> rpc_servers;
+
+    // list of devices used in this model
+    std::vector<ggml_backend_dev_t> devices;
+
+
+    // lists of buffer types used for each layer
+    using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
+    buft_list_t cpu_buft_list;
+    std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
+
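+    // assignment of a layer to a backend device, together with the buffer types to try for its tensors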
+    struct layer_dev {
+        ggml_backend_dev_t dev;
+        buft_list_t * buft_list;
+    };
+
+    layer_dev dev_input = {};
+    layer_dev dev_output = {};
+    std::vector<layer_dev> dev_layer;
+
+    // contexts where the model tensors metadata is stored
+    std::vector<ggml_context_ptr> ctxs;
+
+    // the model memory buffers for the tensor data
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    // model memory mapped files
+    llama_mmaps mappings;
+
+    // objects representing data potentially being locked in memory
+    llama_mlocks mlock_bufs;
+    llama_mlocks mlock_mmaps;
+
+    // for quantize-stats only
+    std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
+
+    int64_t t_load_us  = 0;
+    int64_t t_start_us = 0;
+
+    // total number of parameters in the model
+    uint64_t n_elements = 0;
+
+    // total size of all the tensors in the model in bytes
+    size_t  n_bytes     = 0;
+};
+
+const char * llm_type_name(llm_type type);
+
+std::string llama_model_arch_name (const llama_model & model);
+std::string llama_model_type_name (const llama_model & model);
+std::string llama_model_ftype_name(const llama_model & model);
+
+// used by llama_adapter_cvec
+ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);
+
+// used by llama_adapter_lora
+struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);
+
+size_t llama_model_max_nodes(const llama_model & model);
+
+struct llama_model_loader;
+
+// TODO: become llama_model methods
+void llm_load_stats     (llama_model_loader & ml, llama_model & model);
+void llm_load_arch      (llama_model_loader & ml, llama_model & model);
+void llm_load_hparams   (llama_model_loader & ml, llama_model & model);
+void llm_load_vocab     (llama_model_loader & ml, llama_model & model);
+void llm_load_print_meta(llama_model_loader & ml, llama_model & model);
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
new file mode 100644
index 000000000..466e7bc61
--- /dev/null
+++ b/src/llama-quant.cpp
@@ -0,0 +1,929 @@
+#include "llama-quant.h"
+
+#include "llama-impl.h"
+#include "llama-model.h"
+#include "llama-model-loader.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <cinttypes>
+#include <fstream>
+#include <mutex>
+#include <thread>
+#include <unordered_map>
+
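+// write n zero bytes - used to reserve space for the GGUF metadata, which is written after the tensor data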
+static void zeros(std::ofstream & file, size_t n) {
+    char zero = 0;
+    for (size_t i = 0; i < n; ++i) {
+        file.write(&zero, 1);
+    }
+}
+
+struct quantize_state_impl {
+    const llama_model                 & model;
+    const llama_model_quantize_params * params;
+
+    int n_attention_wv = 0;
+    int n_ffn_down     = 0;
+    int n_ffn_gate     = 0;
+    int n_ffn_up       = 0;
+    int i_attention_wv = 0;
+    int i_ffn_down     = 0;
+    int i_ffn_gate     = 0;
+    int i_ffn_up       = 0;
+
+    int n_k_quantized = 0;
+    int n_fallback    = 0;
+
+    bool has_imatrix = false;
+
+    // used to figure out if a model shares tok_embd with the output weight
+    bool has_output = false;
+
+    quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
+        : model(model)
+        , params(params)
+        {}
+};
+
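+// dequantize/convert a tensor to F32, optionally splitting the work across nthread worker threads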
+static void llama_tensor_dequantize_impl(
+    struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
+    const size_t nelements, const int nthread
+) {
+    if (output.size() < nelements) {
+        output.resize(nelements);
+    }
+    float * f32_output = (float *) output.data();
+
+    const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
+    if (ggml_is_quantized(tensor->type)) {
+        if (qtype->to_float == NULL) {
+            throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
+        }
+    } else if (tensor->type != GGML_TYPE_F16 &&
+               tensor->type != GGML_TYPE_BF16) {
+        throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
+    }
+
+    if (nthread < 2) {
+        if (tensor->type == GGML_TYPE_F16) {
+            ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
+        } else if (tensor->type == GGML_TYPE_BF16) {
+            ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
+        } else if (ggml_is_quantized(tensor->type)) {
+            qtype->to_float(tensor->data, f32_output, nelements);
+        } else {
+            GGML_ABORT("fatal error"); // unreachable
+        }
+        return;
+    }
+
+    size_t block_size;
+    if (tensor->type == GGML_TYPE_F16 ||
+        tensor->type == GGML_TYPE_BF16) {
+        block_size = 1;
+    } else {
+        block_size = (size_t)ggml_blck_size(tensor->type);
+    }
+
+    size_t block_size_bytes = ggml_type_size(tensor->type);
+
+    GGML_ASSERT(nelements % block_size == 0);
+    size_t nblocks = nelements / block_size;
+    size_t blocks_per_thread = nblocks / nthread;
+    size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
+
+    size_t in_buff_offs = 0;
+    size_t out_buff_offs = 0;
+
+    for (int tnum = 0; tnum < nthread; tnum++) {
+        size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
+        size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
+        size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
+
+        auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
+            if (typ == GGML_TYPE_F16) {
+                ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
+            } else if (typ == GGML_TYPE_BF16) {
+                ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
+            } else {
+                qtype->to_float(inbuf, outbuf, nels);
+            }
+        };
+        workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
+        in_buff_offs += thr_block_bytes;
+        out_buff_offs += thr_elems;
+    }
+    for (auto & w : workers) { w.join(); }
+    workers.clear();
+}
+
+static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
+    const std::string name = ggml_get_name(tensor);
+
+    // TODO: avoid hardcoded tensor names - use the TN_* constants
+    const llm_arch arch = qs.model.arch;
+    const auto       tn = LLM_TN(arch);
+
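+    // spend extra bits on the first and last eighth of the layers, and on every third layer in between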
+    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
+        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
+    };
+    const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
+    auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
+        if (n_expert > 1) {
+            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
+            // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
+            // for getting the current layer as I initially thought, and we need to resort to parsing the
+            // tensor name.
+            if (sscanf(name, "blk.%d.", &i_layer) != 1) {
+                throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
+            }
+            if (i_layer < 0 || i_layer >= n_layer) {
+                throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
+            }
+        }
+        return std::make_pair(i_layer, n_layer);
+    };
+
+    // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
+    // with the quantization of the output tensor
+    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
+        if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
+            new_type = qs.params->output_tensor_type;
+        } else {
+            const int64_t nx = tensor->ne[0];
+            const int64_t qk_k = ggml_blck_size(new_type);
+
+            if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
+                new_type = GGML_TYPE_Q8_0;
+            }
+            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
+                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S  || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M   ||
+                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
+                new_type = GGML_TYPE_Q5_K;
+            }
+            else if (new_type != GGML_TYPE_Q8_0) {
+                new_type = GGML_TYPE_Q6_K;
+            }
+        }
+    } else if (name == "token_embd.weight") {
+        if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
+            new_type = qs.params->token_embedding_type;
+        } else {
+            if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
+                ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
+                new_type = GGML_TYPE_Q2_K;
+            }
+            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
+                new_type = GGML_TYPE_IQ3_S;
+            }
+            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
+                new_type = GGML_TYPE_IQ3_S;
+            }
+            else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+        }
+    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
+               ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
+        if (name.find("attn_v.weight") != std::string::npos) {
+            if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
+            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            ++qs.i_attention_wv;
+        }
+        else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
+            new_type = GGML_TYPE_Q4_K;
+        }
+        else if (name.find("ffn_down") != std::string::npos) {
+            if (qs.i_ffn_down < qs.n_ffn_down/8) {
+                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            }
+            ++qs.i_ffn_down;
+        }
+        else if (name.find("attn_output.weight") != std::string::npos) {
+            if (qs.model.hparams.n_expert == 8) {
+                new_type = GGML_TYPE_Q5_K;
+            } else {
+                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
+            }
+        }
+    } else if (name.find("attn_v.weight") != std::string::npos) {
+        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
+            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
+            new_type = GGML_TYPE_Q4_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
+            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
+        }
+        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
+            new_type = GGML_TYPE_Q4_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
+            new_type = GGML_TYPE_Q4_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
+            new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
+        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
+            new_type = GGML_TYPE_Q5_K;
+        }
+        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
+                use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
+        if (qs.model.type == MODEL_70B) {
+            // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
+            // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
+            // nearly negligible increase in model size by quantizing this tensor with more bits:
+            if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
+        }
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
+        ++qs.i_attention_wv;
+    } else if (name.find("attn_k.weight") != std::string::npos) {
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
+            new_type = GGML_TYPE_IQ3_XXS;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
+            new_type = GGML_TYPE_IQ2_S;
+        }
+    } else if (name.find("attn_q.weight") != std::string::npos) {
+        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
+            new_type = GGML_TYPE_IQ3_XXS;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
+            new_type = GGML_TYPE_IQ2_S;
+        }
+    } else if (name.find("ffn_down") != std::string::npos) {
+        auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
+        int i_layer = info.first, n_layer = info.second;
+        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
+            if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
+            new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
+            new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
+                     : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
+                     : GGML_TYPE_Q3_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
+                    (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
+            new_type = GGML_TYPE_Q4_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
+            new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
+            if (arch == LLM_ARCH_FALCON) {
+                new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
+                           use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
+            } else {
+                if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
+            }
+        }
+        else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
+            new_type = GGML_TYPE_Q5_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
+            new_type = GGML_TYPE_Q5_K;
+        }
+        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
+                && qs.has_imatrix && i_layer < n_layer/8) {
+            // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
+            // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
+            // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
+            new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
+        }
+        ++qs.i_ffn_down;
+    } else if (name.find("attn_output.weight") != std::string::npos) {
+        if (arch != LLM_ARCH_FALCON) {
+            if (qs.model.hparams.n_expert == 8) {
+                if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
+                    ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL  ||
+                    ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S  ||
+                    ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
+                    new_type = GGML_TYPE_Q5_K;
+                }
+            } else {
+                if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   ) new_type = GGML_TYPE_Q3_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  ) new_type = GGML_TYPE_Q4_K;
+            }
+        } else {
+            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
+        }
+    }
+    else if (name.find("attn_qkv.weight") != std::string::npos) {
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
+            new_type = GGML_TYPE_Q4_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
+    }
+    else if (name.find("ffn_gate") != std::string::npos) {
+        auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
+        int i_layer = info.first, n_layer = info.second;
+        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
+            new_type = GGML_TYPE_IQ3_XXS;
+        }
+        ++qs.i_ffn_gate;
+    }
+    else if (name.find("ffn_up") != std::string::npos) {
+        auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
+        int i_layer = info.first, n_layer = info.second;
+        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
+            new_type = GGML_TYPE_IQ3_XXS;
+        }
+        ++qs.i_ffn_up;
+    }
+
+    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
+    //}
+    // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
+    //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
+    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
+    //}
+    // This can be used to reduce the size of the Q5_K_S model.
+    // The associated PPL increase is fully in line with the size reduction
+    //else {
+    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
+    //}
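+    // if the number of columns is not a multiple of the block size of the chosen type, fall back to a compatible type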
+    bool convert_incompatible_tensor = false;
+    {
+        const int64_t nx = tensor->ne[0];
+        const int64_t ny = tensor->ne[1];
+        const int64_t qk_k = ggml_blck_size(new_type);
+
+        if (nx % qk_k != 0) {
+            LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
+            convert_incompatible_tensor = true;
+        } else {
+            ++qs.n_k_quantized;
+        }
+    }
+
+    if (convert_incompatible_tensor) {
+        switch (new_type) {
+            case GGML_TYPE_TQ1_0:
+            case GGML_TYPE_TQ2_0:  new_type = GGML_TYPE_Q4_0; break;  // TODO: use a symmetric type instead
+            case GGML_TYPE_IQ2_XXS:
+            case GGML_TYPE_IQ2_XS:
+            case GGML_TYPE_IQ2_S:
+            case GGML_TYPE_IQ3_XXS:
+            case GGML_TYPE_IQ3_S:
+            case GGML_TYPE_IQ1_S:
+            case GGML_TYPE_IQ1_M:
+            case GGML_TYPE_Q2_K:
+            case GGML_TYPE_Q3_K:
+            case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
+            case GGML_TYPE_Q4_K:   new_type = GGML_TYPE_Q5_0;   break;
+            case GGML_TYPE_Q5_K:   new_type = GGML_TYPE_Q5_1;   break;
+            case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
+            default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
+        }
+        if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
+            new_type = GGML_TYPE_F16;
+        }
+        LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
+        ++qs.n_fallback;
+    }
+
+    return new_type;
+}
+
+static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
+    if (nthread < 2) {
+        // single-thread
+        size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
+        if (!ggml_validate_row_data(new_type, new_data, new_size)) {
+            throw std::runtime_error("quantized data validation failed");
+        }
+        return new_size;
+    }
+
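+    // multi-threaded: the shared counter hands out chunks of rows to the workers; each chunk is quantized and validated independently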
+    std::mutex mutex;
+    int64_t counter = 0;
+    size_t new_size = 0;
+    bool valid = true;
+    auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
+            nrows, n_per_row, imatrix]() {
+        const int64_t nrows_per_chunk = chunk_size / n_per_row;
+        size_t local_size = 0;
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex);
+            int64_t first_row = counter; counter += nrows_per_chunk;
+            if (first_row >= nrows) {
+                if (local_size > 0) {
+                    new_size += local_size;
+                }
+                break;
+            }
+            lock.unlock();
+            const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
+            size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
+            local_size += this_size;
+
+            // validate the quantized data
+            const size_t row_size  = ggml_row_size(new_type, n_per_row);
+            void * this_data = (char *) new_data + first_row * row_size;
+            if (!ggml_validate_row_data(new_type, this_data, this_size)) {
+                std::unique_lock<std::mutex> lock(mutex);
+                valid = false;
+                break;
+            }
+        }
+    };
+    for (int it = 0; it < nthread - 1; ++it) {
+        workers.emplace_back(compute);
+    }
+    compute();
+    for (auto & w : workers) { w.join(); }
+    workers.clear();
+    if (!valid) {
+        throw std::runtime_error("quantized data validation failed");
+    }
+    return new_size;
+}
+
+static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
+    ggml_type default_type;
+    llama_ftype ftype = params->ftype;
+
+    switch (params->ftype) {
+        case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
+        case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
+        case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
+        case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
+        case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
+        case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
+        case LLAMA_FTYPE_ALL_F32:     default_type = GGML_TYPE_F32;  break;
+
+        // K-quants
+        case LLAMA_FTYPE_MOSTLY_Q2_K_S:
+        case LLAMA_FTYPE_MOSTLY_Q2_K:    default_type = GGML_TYPE_Q2_K;    break;
+        case LLAMA_FTYPE_MOSTLY_IQ3_XS:  default_type = GGML_TYPE_IQ3_S;   break;
+        case LLAMA_FTYPE_MOSTLY_Q3_K_S:
+        case LLAMA_FTYPE_MOSTLY_Q3_K_M:
+        case LLAMA_FTYPE_MOSTLY_Q3_K_L:  default_type = GGML_TYPE_Q3_K;    break;
+        case LLAMA_FTYPE_MOSTLY_Q4_K_S:
+        case LLAMA_FTYPE_MOSTLY_Q4_K_M:  default_type = GGML_TYPE_Q4_K;    break;
+        case LLAMA_FTYPE_MOSTLY_Q5_K_S:
+        case LLAMA_FTYPE_MOSTLY_Q5_K_M:  default_type = GGML_TYPE_Q5_K;    break;
+        case LLAMA_FTYPE_MOSTLY_Q6_K:    default_type = GGML_TYPE_Q6_K;    break;
+        case LLAMA_FTYPE_MOSTLY_TQ1_0:   default_type = GGML_TYPE_TQ1_0;   break;
+        case LLAMA_FTYPE_MOSTLY_TQ2_0:   default_type = GGML_TYPE_TQ2_0;   break;
+        case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
+        case LLAMA_FTYPE_MOSTLY_IQ2_XS:  default_type = GGML_TYPE_IQ2_XS;  break;
+        case LLAMA_FTYPE_MOSTLY_IQ2_S:   default_type = GGML_TYPE_IQ2_XS;  break;
+        case LLAMA_FTYPE_MOSTLY_IQ2_M:   default_type = GGML_TYPE_IQ2_S;   break;
+        case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
+        case LLAMA_FTYPE_MOSTLY_IQ1_S:   default_type = GGML_TYPE_IQ1_S;   break;
+        case LLAMA_FTYPE_MOSTLY_IQ1_M:   default_type = GGML_TYPE_IQ1_M;   break;
+        case LLAMA_FTYPE_MOSTLY_IQ4_NL:  default_type = GGML_TYPE_IQ4_NL;  break;
+        case LLAMA_FTYPE_MOSTLY_IQ4_XS:  default_type = GGML_TYPE_IQ4_XS;  break;
+        case LLAMA_FTYPE_MOSTLY_IQ3_S:   default_type = GGML_TYPE_IQ3_S;   break;
+        case LLAMA_FTYPE_MOSTLY_IQ3_M:   default_type = GGML_TYPE_IQ3_S;   break;
+
+        default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
+    }
+
+    int nthread = params->nthread;
+
+    if (nthread <= 0) {
+        nthread = std::thread::hardware_concurrency();
+    }
+
+    // mmap consistently increases speed on Linux, and also increases speed on Windows with a
+    // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
+#if defined(__linux__) || defined(_WIN32)
+    constexpr bool use_mmap = true;
+#else
+    constexpr bool use_mmap = false;
+#endif
+
+    llama_model_kv_override * kv_overrides = nullptr;
+    if (params->kv_overrides) {
+        auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
+        kv_overrides = v->data();
+    }
+    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
+    ml.init_mappings(false); // no prefetching
+
+    llama_model model;
+    llm_load_arch   (ml, model);
+    llm_load_hparams(ml, model);
+    llm_load_stats  (ml, model);
+
+    struct quantize_state_impl qs(model, params);
+
+    if (params->only_copy) {
+        ftype = model.ftype;
+    }
+    const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
+    if (params->imatrix) {
+        imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
+        if (imatrix_data) {
+            LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
+            qs.has_imatrix = true;
+            // check imatrix for nans or infs
+            for (const auto & kv : *imatrix_data) {
+                for (float f : kv.second) {
+                    if (!std::isfinite(f)) {
+                        throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
+                    }
+                }
+            }
+        }
+    }
+
+    const size_t align = GGUF_DEFAULT_ALIGNMENT;
+    gguf_context_ptr ctx_out { gguf_init_empty() };
+
+    // copy the KV pairs from the input file
+    gguf_set_kv     (ctx_out.get(), ml.meta.get());
+    gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
+    gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
+
+    // Remove split metadata
+    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
+    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
+    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
+
+    if (params->kv_overrides) {
+        const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
+        for (const auto & o : overrides) {
+            if (o.key[0] == 0) break;
+            if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
+                gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
+            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
+                gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64);
+            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
+                gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
+            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
+                gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
+            } else {
+                LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
+            }
+        }
+    }
+
+    // make a list of weights
+    std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
+    tensors.reserve(ml.weights_map.size());
+    for (const auto & it : ml.weights_map) {
+        tensors.push_back(&it.second);
+    }
+
+    // keep_split requires that the weights are sorted by split index
+    if (params->keep_split) {
+        std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
+            if (a->idx == b->idx) {
+                return a->offs < b->offs;
+            }
+            return a->idx < b->idx;
+        });
+    }
+
+    for (const auto * it : tensors) {
+        const struct ggml_tensor * tensor = it->tensor;
+
+        const std::string name = ggml_get_name(tensor);
+
+        // TODO: avoid hardcoded tensor names - use the TN_* constants
+        if (name.find("attn_v.weight")   != std::string::npos ||
+            name.find("attn_qkv.weight") != std::string::npos ||
+            name.find("attn_kv_b.weight")!= std::string::npos) {
+            ++qs.n_attention_wv;
+        } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
+            qs.has_output = true;
+        }
+    }
+
+    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
+
+    // sanity checks
+    {
+        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
+        // attention layers have a non-zero number of kv heads
+        int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
+        if (llama_model_has_encoder(&model)) {
+            n_attn_layer *= 3;
+        }
+        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+    }
+
+    size_t total_size_org = 0;
+    size_t total_size_new = 0;
+
+    std::vector<std::thread> workers;
+    workers.reserve(nthread);
+
+    int idx = 0;
+
+    std::vector<no_init<uint8_t>> read_data;
+    std::vector<no_init<uint8_t>> work;
+    std::vector<no_init<float>> f32_conv_buf;
+
+    uint16_t n_split = 1;
+
+    // Assume split index is continuous
+    if (params->keep_split) {
+        for (const auto * it : tensors) {
+            n_split = std::max(uint16_t(it->idx + 1), n_split);
+        }
+    }
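+    // one output gguf context per split; split 0 takes over the context that already holds the copied KV pairs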
+    std::vector<gguf_context_ptr> ctx_outs(n_split);
+    ctx_outs[0] = std::move(ctx_out);
+
+    // populate the original tensors so we get an initial meta data
+    for (const auto * it : tensors) {
+        uint16_t i_split = params->keep_split ? it->idx : 0;
+        struct ggml_tensor * tensor = it->tensor;
+        if (!ctx_outs[i_split]) {
+            ctx_outs[i_split].reset(gguf_init_empty());
+        }
+        gguf_add_tensor(ctx_outs[i_split].get(), tensor);
+    }
+
+    // Set split info if needed
+    if (n_split > 1) {
+        for (size_t i = 0; i < ctx_outs.size(); ++i) {
+            gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
+            gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
+            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
+        }
+    }
+
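+    // with keep_split, each split index gets its own output file; otherwise everything is written to a single file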
+    int cur_split = -1;
+    std::ofstream fout;
+    auto close_ofstream = [&]() {
+        // Write metadata and close file handler
+        if (fout.is_open()) {
+            fout.seekp(0);
+            std::vector<uint8_t> data(gguf_get_meta_size(ctx_outs[cur_split].get()));
+            gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
+            fout.write((const char *) data.data(), data.size());
+            fout.close();
+        }
+    };
+    auto new_ofstream = [&](int index) {
+        cur_split = index;
+        GGML_ASSERT(ctx_outs[cur_split] && "Found uninitialized gguf_context");
+        std::string fname = fname_out;
+        if (params->keep_split) {
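+            // derive the per-split output filename from fname_out (e.g. "<fname>-00001-of-00005.gguf")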
+            std::vector<char> split_path(llama_path_max(), 0);
+            llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split);
+            fname = std::string(split_path.data());
+        }
+
+        fout = std::ofstream(fname, std::ios::binary);
+        fout.exceptions(std::ofstream::failbit); // fail fast on write errors
+        const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get());
+        // placeholder for the meta data
+        ::zeros(fout, meta_size);
+    };
+
+    const auto tn = LLM_TN(model.arch);
+    new_ofstream(0);
+    for (const auto * it : tensors) {
+        const auto & weight = *it;
+        struct ggml_tensor * tensor = weight.tensor;
+        if (weight.idx != cur_split && params->keep_split) {
+            close_ofstream();
+            new_ofstream(weight.idx);
+        }
+
+        const std::string name = ggml_get_name(tensor);
+
+        if (!ml.use_mmap) {
+            if (read_data.size() < ggml_nbytes(tensor)) {
+                read_data.resize(ggml_nbytes(tensor));
+            }
+            tensor->data = read_data.data();
+        }
+        ml.load_data_for(tensor);
+
+        LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
+               ++idx, ml.n_tensors,
+               ggml_get_name(tensor),
+               llama_format_tensor_shape(tensor).c_str(),
+               ggml_type_name(tensor->type));
+
+        // This used to be a regex, but <regex> has an extreme cost to compile times.
+        bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
+
+        // quantize only 2D and 3D tensors (experts)
+        quantize &= (ggml_n_dims(tensor) >= 2);
+
+        // do not quantize norm tensors
+        quantize &= name.find("_norm.weight") == std::string::npos;
+
+        quantize &= params->quantize_output_tensor || name != "output.weight";
+        quantize &= !params->only_copy;
+
+        // do not quantize expert gating tensors
+        // NOTE: can't use LLM_TN here because the layer number is not known
+        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
+
+        // do not quantize positional embeddings and token types (BERT)
+        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD,    "weight");
+        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
+
+        // do not quantize Mamba's small yet 2D weights
+        // NOTE: can't use LLM_TN here because the layer number is not known
+        quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
+
+        // do not quantize RWKV's time_mix_first tensors
+        quantize &= name.find("time_mix_first.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+
+        // do not quantize relative position bias (T5)
+        quantize &= name.find("attn_rel_b.weight") == std::string::npos;
+
+        enum ggml_type new_type;
+        void * new_data;
+        size_t new_size;
+
+        if (quantize) {
+            new_type = default_type;
+
+            // get more optimal quantization type based on the tensor shape, layer, etc.
+            if (!params->pure && ggml_is_quantized(default_type)) {
+                new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
+            }
+            if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
+                new_type = params->token_embedding_type;
+            }
+            if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
+                new_type = params->output_tensor_type;
+            }
+
+            // If we've decided to quantize to the same type the tensor is already
+            // in then there's nothing to do.
+            quantize = tensor->type != new_type;
+        }
+
+        if (!quantize) {
+            new_type = tensor->type;
+            new_data = tensor->data;
+            new_size = ggml_nbytes(tensor);
+            LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
+        } else {
+            const int64_t nelements = ggml_nelements(tensor);
+
+            const float * imatrix = nullptr;
+            if (imatrix_data) {
+                auto it = imatrix_data->find(tensor->name);
+                if (it == imatrix_data->end()) {
+                    LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
+                } else {
+                    if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
+                        imatrix = it->second.data();
+                    } else {
+                        LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
+                                int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
+
+                        // this can happen when quantizing an old mixtral model with split tensors using a new, incompatible imatrix
+                        // this is a significant error and it may be a good idea to abort the process if this happens,
+                        // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
+                        // tok_embd should be ignored in this case, since it always causes this warning
+                        if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
+                            throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
+                                    int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
+                        }
+                    }
+                }
+            }
+            if ((new_type == GGML_TYPE_IQ2_XXS ||
+                 new_type == GGML_TYPE_IQ2_XS  ||
+                 new_type == GGML_TYPE_IQ2_S   ||
+                 new_type == GGML_TYPE_IQ1_S   ||
+                (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight"))  ||
+                (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
+                LLAMA_LOG_ERROR("\n\n============================================================\n");
+                LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
+                LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
+                LLAMA_LOG_ERROR("============================================================\n\n");
+                throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
+            }
+
+            float * f32_data;
+
+            if (tensor->type == GGML_TYPE_F32) {
+                f32_data = (float *) tensor->data;
+            } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
+                throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
+            } else {
+                llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
+                f32_data = (float *) f32_conv_buf.data();
+            }
+
+            LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
+            fflush(stdout);
+
+            if (work.size() < (size_t)nelements * 4) {
+                work.resize(nelements * 4); // upper bound on size
+            }
+            new_data = work.data();
+
+            const int64_t n_per_row = tensor->ne[0];
+            const int64_t nrows = tensor->ne[1];
+
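+            // split the work into chunks of whole rows; short rows are grouped so each chunk has at least min_chunk_size elements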
+            static const int64_t min_chunk_size = 32 * 512;
+            const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
+
+            const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
+            const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
+            const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
+
+            // quantize each expert separately since they have different importance matrices
+            new_size = 0;
+            for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
+                const float * f32_data_03 = f32_data + i03 * nelements_matrix;
+                void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
+                const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
+
+                new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
+            }
+            LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
+        }
+        total_size_org += ggml_nbytes(tensor);
+        total_size_new += new_size;
+
+        // update the gguf meta data as we go
+        gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
+        GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
+        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
+
+        // write tensor data + padding
+        fout.write((const char *) new_data, new_size);
+        zeros(fout, GGML_PAD(new_size, align) - new_size);
+    }
+    close_ofstream();
+
+    LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
+    LLAMA_LOG_INFO("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
+
+    if (qs.n_fallback > 0) {
+        LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
+                __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
+    }
+}
+
+//
+// interface implementation
+//
+
+struct llama_model_quantize_params llama_model_quantize_default_params() {
+    struct llama_model_quantize_params result = {
+        /*.nthread                     =*/ 0,
+        /*.ftype                       =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
+        /*.output_tensor_type          =*/ GGML_TYPE_COUNT,
+        /*.token_embedding_type        =*/ GGML_TYPE_COUNT,
+        /*.allow_requantize            =*/ false,
+        /*.quantize_output_tensor      =*/ true,
+        /*.only_copy                   =*/ false,
+        /*.pure                        =*/ false,
+        /*.keep_split                  =*/ false,
+        /*.imatrix                     =*/ nullptr,
+        /*.kv_overrides                =*/ nullptr,
+    };
+
+    return result;
+}
+
+uint32_t llama_model_quantize(
+        const char * fname_inp,
+        const char * fname_out,
+        const llama_model_quantize_params * params) {
+    try {
+        llama_model_quantize_impl(fname_inp, fname_out, params);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
+        return 1;
+    }
+
+    return 0;
+}
diff --git a/src/llama-quant.h b/src/llama-quant.h
new file mode 100644
index 000000000..6f70f09be
--- /dev/null
+++ b/src/llama-quant.h
@@ -0,0 +1 @@
+#pragma once
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index bebff77cf..ef5a576cc 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1,5 +1,6 @@
 #include "llama-sampling.h"
 
+#include "llama-impl.h"
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
@@ -14,6 +15,118 @@
 #include 
 #include 
 #include 
+#include 
+
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
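+    // append a value; when the buffer is full, the oldest element is overwritten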
+    void push_back(const T & value) {
+        if (capacity == 0) {
+            throw std::runtime_error("ring buffer: capacity is zero");
+        }
+
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
+    }
+
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
+
+    //T & operator[](size_t i) {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    //const T & at(size_t i) const {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
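+    // reverse-indexed access: rat(0) is the most recently pushed element, rat(size() - 1) the oldest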
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + sz - i - 1) % capacity];
+    }
+
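+    // copy the contents into a contiguous vector, oldest element first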
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
+
+    void clear() {
+        // only reset the bookkeeping; the stored elements are not cleared
+        sz = 0;
+        first = 0;
+        pos = 0;
+    }
+
+    bool empty() const {
+        return sz == 0;
+    }
+
+    size_t size() const {
+        return sz;
+    }
+
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+
+    std::vector<T> data;
+};
 
 static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
     // iterator for the probabilities
@@ -144,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
             for (int i = 0; i < (int)cur_p->size; ++i) {
                 const float val = cur_p->data[i].logit;
                 int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets-1, ib));
+                ib = std::max(0, std::min(nbuckets - 1, ib));
                 bucket_idx[i] = ib;
                 ++histo[ib];
             }
@@ -167,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
             for (int i = 0; i < (int)cur_p->size; ++i) {
                 int j = bucket_idx[i];
                 if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
+                    *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
                 }
             }
 
             ptr = tmp_tokens.data();
             int ndone = 0;
-            for (int j = nbuckets-1; j > ib; --j) {
+            for (int j = nbuckets - 1; j > ib; --j) {
                 std::sort(ptr, ptr + histo[j], comp);
                 ptr += histo[j];
                 ndone += histo[j];
@@ -1719,7 +1832,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
                 ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                 if (n > 0) {
                     lt = k;
-                    rt = k+n-1;
+                    rt = k + n - 1;
                 }
             } else {
                 // If k is inside the current Z-box, consider two cases.
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 0a477d6dd..a4c015484 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1,5 +1,7 @@
 #include "llama-vocab.h"
 
+#include "llama-impl.h"
+
 #include "unicode.h"
 
 #include 
@@ -16,22 +18,6 @@
 // helpers
 //
 
-LLAMA_ATTRIBUTE_FORMAT(1, 2)
-static std::string format(const char * fmt, ...) {
-    va_list ap;
-    va_list ap2;
-    va_start(ap, fmt);
-    va_copy(ap2, ap);
-    int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
-    std::vector<char> buf(size + 1);
-    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
-    GGML_ASSERT(size2 == size);
-    va_end(ap2);
-    va_end(ap);
-    return std::string(buf.data(), size);
-}
-
 struct naive_trie {
     naive_trie() : has_value(false), value(0) {
     }
@@ -396,6 +382,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "\\p{N}+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
+                regex_exprs = {
+                    "\\p{N}{1,3}",
+                    "[一-龥぀-ゟ゠-ヿ]+",
+                    "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
                 regex_exprs = {
                     "[\r\n]",
@@ -504,7 +497,7 @@ struct llm_tokenizer_bpe_session {
 
     bool append_bos(std::vector<llama_token> & output) const {
         if (vocab.tokenizer_add_bos) {
-            GGML_ASSERT(vocab.special_bos_id != -1);
+            GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
             output.push_back(vocab.special_bos_id);
             return true;
         }
@@ -513,7 +506,7 @@ struct llm_tokenizer_bpe_session {
 
     bool append_eos(std::vector<llama_token> & output) const {
         if (vocab.tokenizer_add_eos) {
-            GGML_ASSERT(vocab.special_eos_id != -1);
+            GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
             output.push_back(vocab.special_eos_id);
             return true;
         }
@@ -1410,7 +1403,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
                         }
 
                         // repeat for the right side
@@ -1424,7 +1417,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
                         }
                         break;
                     }
@@ -1461,7 +1454,7 @@ std::vector<llama_token> llama_tokenize_internal(
                 bool is_prev_special = true;  // prefix with space if first token
 
                 if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
+                    GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
                     output.push_back(vocab.special_bos_id);
                     is_prev_special = true;
                 }
@@ -1496,7 +1489,7 @@ std::vector llama_tokenize_internal(
                 }
 
                 if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
+                    GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
@@ -1529,7 +1522,7 @@ std::vector<llama_token> llama_tokenize_internal(
         case LLAMA_VOCAB_TYPE_WPM:
             {
                 if (add_special) {
-                    GGML_ASSERT(vocab.special_cls_id != -1);
+                    GGML_ASSERT(vocab.special_cls_id != LLAMA_TOKEN_NULL);
                     output.push_back(vocab.special_cls_id);
                 }
 
@@ -1549,14 +1542,14 @@ std::vector<llama_token> llama_tokenize_internal(
                 }
 
                 if (add_special) {
-                    GGML_ASSERT(vocab.special_sep_id != -1);
+                    GGML_ASSERT(vocab.special_sep_id != LLAMA_TOKEN_NULL);
                     output.push_back(vocab.special_sep_id);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_UGM:
             {
                 if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
+                    GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
                     output.push_back(vocab.special_bos_id);
                 }
                 llm_tokenizer_ugm_session session(vocab);
@@ -1581,7 +1574,7 @@ std::vector<llama_token> llama_tokenize_internal(
                 }
 
                 if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
+                    GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
@@ -1649,7 +1642,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
 }
 
 bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && vocab.special_eog_ids.count(token) > 0;
+    return token != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(token) > 0;
 }
 
 bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
@@ -1888,7 +1881,7 @@ int32_t llama_detokenize_impl(
     }
 
     if (remove_special && vocab.tokenizer_add_eos) {
-        if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
+        if (n_tokens > 0 && tokens[n_tokens - 1] == vocab.special_eos_id) {
             n_tokens--;
         }
     }
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index a9b0da5ef..0d00086da 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "llama-impl.h"
+#include "llama.h"
 
 #include 
 #include 
@@ -8,6 +8,18 @@
 #include 
 #include 
 
+static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
+        default:                    return "unknown";
+    }
+}
+
 struct llm_tokenizer;
 
 struct llama_vocab {
diff --git a/src/llama.cpp b/src/llama.cpp
index 4d41602fe..ae375bcd3 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1,7472 +1,45 @@
 #include "llama-impl.h"
+
+#include "llama-chat.h"
+#include "llama-mmap.h"
+#include "llama-context.h"
 #include "llama-vocab.h"
 #include "llama-sampling.h"
-
-#include "unicode.h"
+#include "llama-kv-cache.h"
+#include "llama-model-loader.h"
+#include "llama-model.h"
 
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "ggml-cpp.h"
 
-// TODO: replace with ggml API call
-#define QK_K 256
-
-#ifdef __has_include
-    #if __has_include(<unistd.h>)
-        #include <unistd.h>
-        #if defined(_POSIX_MAPPED_FILES)
-            #include <sys/mman.h>
-            #include <fcntl.h>
-        #endif
-        #if defined(_POSIX_MEMLOCK_RANGE)
-            #include <sys/resource.h>
-        #endif
-    #endif
-#endif
-
-#if defined(_WIN32)
-    #define WIN32_LEAN_AND_MEAN
-    #ifndef NOMINMAX
-        #define NOMINMAX
-    #endif
-    #include <windows.h>
-    #ifndef PATH_MAX
-        #define PATH_MAX MAX_PATH
-    #endif
-    #include <io.h>
-#endif
-
-#if __cplusplus >= 202000L
-    #define LU8(x) (const char*)(u8##x)
-#else
-    #define LU8(x) u8##x
-#endif
-
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
-#include 
 #include 
-#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
 #include 
-#include 
 #include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-// bump if necessary
-#define LLAMA_MAX_LAYERS  512
-#define LLAMA_MAX_EXPERTS 160  // DeepSeekV2
-
 //
-// helpers
+// tensor loading (TODO: add llama_tensor_loader?)
 //
 
-// trim whitespace from the beginning and end of a string
-static std::string trim(const std::string & str) {
-    size_t start = 0;
-    size_t end = str.size();
-    while (start < end && isspace(str[start])) {
-        start += 1;
-    }
-    while (end > start && isspace(str[end - 1])) {
-        end -= 1;
-    }
-    return str.substr(start, end - start);
-}
-
-static bool is_float_close(float a, float b, float abs_tol) {
-    // Check for non-negative tolerance
-    if (abs_tol < 0.0) {
-        throw std::invalid_argument("Tolerance must be non-negative");
-    }
-
-    // Exact equality check
-    if (a == b) {
-        return true;
-    }
-
-    // Check for infinities
-    if (std::isinf(a) || std::isinf(b)) {
-        return false;
-    }
-
-    // Regular comparison using the provided absolute tolerance
-    return std::fabs(b - a) <= abs_tol;
-}
-
-static void zeros(std::ofstream & file, size_t n) {
-    char zero = 0;
-    for (size_t i = 0; i < n; ++i) {
-        file.write(&zero, 1);
-    }
-}
-
-LLAMA_ATTRIBUTE_FORMAT(1, 2)
-static std::string format(const char * fmt, ...) {
-    va_list ap;
-    va_list ap2;
-    va_start(ap, fmt);
-    va_copy(ap2, ap);
-    int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
-    std::vector<char> buf(size + 1);
-    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
-    GGML_ASSERT(size2 == size);
-    va_end(ap2);
-    va_end(ap);
-    return std::string(buf.data(), size);
-}
-
-//
-// gguf constants (sync with gguf.py)
-//
-
-enum llm_arch {
-    LLM_ARCH_LLAMA,
-    LLM_ARCH_DECI,
-    LLM_ARCH_FALCON,
-    LLM_ARCH_BAICHUAN,
-    LLM_ARCH_GROK,
-    LLM_ARCH_GPT2,
-    LLM_ARCH_GPTJ,
-    LLM_ARCH_GPTNEOX,
-    LLM_ARCH_MPT,
-    LLM_ARCH_STARCODER,
-    LLM_ARCH_REFACT,
-    LLM_ARCH_BERT,
-    LLM_ARCH_NOMIC_BERT,
-    LLM_ARCH_JINA_BERT_V2,
-    LLM_ARCH_BLOOM,
-    LLM_ARCH_STABLELM,
-    LLM_ARCH_QWEN,
-    LLM_ARCH_QWEN2,
-    LLM_ARCH_QWEN2MOE,
-    LLM_ARCH_QWEN2VL,
-    LLM_ARCH_PHI2,
-    LLM_ARCH_PHI3,
-    LLM_ARCH_PLAMO,
-    LLM_ARCH_CODESHELL,
-    LLM_ARCH_ORION,
-    LLM_ARCH_INTERNLM2,
-    LLM_ARCH_MINICPM,
-    LLM_ARCH_MINICPM3,
-    LLM_ARCH_GEMMA,
-    LLM_ARCH_GEMMA2,
-    LLM_ARCH_STARCODER2,
-    LLM_ARCH_MAMBA,
-    LLM_ARCH_XVERSE,
-    LLM_ARCH_COMMAND_R,
-    LLM_ARCH_DBRX,
-    LLM_ARCH_OLMO,
-    LLM_ARCH_OLMO2,
-    LLM_ARCH_OLMOE,
-    LLM_ARCH_OPENELM,
-    LLM_ARCH_ARCTIC,
-    LLM_ARCH_DEEPSEEK,
-    LLM_ARCH_DEEPSEEK2,
-    LLM_ARCH_CHATGLM,
-    LLM_ARCH_BITNET,
-    LLM_ARCH_T5,
-    LLM_ARCH_T5ENCODER,
-    LLM_ARCH_JAIS,
-    LLM_ARCH_NEMOTRON,
-    LLM_ARCH_EXAONE,
-    LLM_ARCH_RWKV6,
-    LLM_ARCH_GRANITE,
-    LLM_ARCH_GRANITE_MOE,
-    LLM_ARCH_CHAMELEON,
-    LLM_ARCH_WAVTOKENIZER_DEC,
-    LLM_ARCH_UNKNOWN,
-};
-
-static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
-    { LLM_ARCH_LLAMA,            "llama"            },
-    { LLM_ARCH_DECI,             "deci"            },
-    { LLM_ARCH_FALCON,           "falcon"           },
-    { LLM_ARCH_GROK,             "grok"             },
-    { LLM_ARCH_GPT2,             "gpt2"             },
-    { LLM_ARCH_GPTJ,             "gptj"             },
-    { LLM_ARCH_GPTNEOX,          "gptneox"          },
-    { LLM_ARCH_MPT,              "mpt"              },
-    { LLM_ARCH_BAICHUAN,         "baichuan"         },
-    { LLM_ARCH_STARCODER,        "starcoder"        },
-    { LLM_ARCH_REFACT,           "refact"           },
-    { LLM_ARCH_BERT,             "bert"             },
-    { LLM_ARCH_NOMIC_BERT,       "nomic-bert"       },
-    { LLM_ARCH_JINA_BERT_V2,     "jina-bert-v2"     },
-    { LLM_ARCH_BLOOM,            "bloom"            },
-    { LLM_ARCH_STABLELM,         "stablelm"         },
-    { LLM_ARCH_QWEN,             "qwen"             },
-    { LLM_ARCH_QWEN2,            "qwen2"            },
-    { LLM_ARCH_QWEN2MOE,         "qwen2moe"         },
-    { LLM_ARCH_QWEN2VL,          "qwen2vl"          },
-    { LLM_ARCH_PHI2,             "phi2"             },
-    { LLM_ARCH_PHI3,             "phi3"             },
-    { LLM_ARCH_PLAMO,            "plamo"            },
-    { LLM_ARCH_CODESHELL,        "codeshell"        },
-    { LLM_ARCH_ORION,            "orion"            },
-    { LLM_ARCH_INTERNLM2,        "internlm2"        },
-    { LLM_ARCH_MINICPM,          "minicpm"          },
-    { LLM_ARCH_MINICPM3,         "minicpm3"         },
-    { LLM_ARCH_GEMMA,            "gemma"            },
-    { LLM_ARCH_GEMMA2,           "gemma2"           },
-    { LLM_ARCH_STARCODER2,       "starcoder2"       },
-    { LLM_ARCH_MAMBA,            "mamba"            },
-    { LLM_ARCH_XVERSE,           "xverse"           },
-    { LLM_ARCH_COMMAND_R,        "command-r"        },
-    { LLM_ARCH_DBRX,             "dbrx"             },
-    { LLM_ARCH_OLMO,             "olmo"             },
-    { LLM_ARCH_OLMO2,            "olmo2"            },
-    { LLM_ARCH_OLMOE,            "olmoe"            },
-    { LLM_ARCH_OPENELM,          "openelm"          },
-    { LLM_ARCH_ARCTIC,           "arctic"           },
-    { LLM_ARCH_DEEPSEEK,         "deepseek"         },
-    { LLM_ARCH_DEEPSEEK2,        "deepseek2"        },
-    { LLM_ARCH_CHATGLM,          "chatglm"          },
-    { LLM_ARCH_BITNET,           "bitnet"           },
-    { LLM_ARCH_T5,               "t5"               },
-    { LLM_ARCH_T5ENCODER,        "t5encoder"        },
-    { LLM_ARCH_JAIS,             "jais"             },
-    { LLM_ARCH_NEMOTRON,         "nemotron"         },
-    { LLM_ARCH_EXAONE,           "exaone"           },
-    { LLM_ARCH_RWKV6,            "rwkv6"            },
-    { LLM_ARCH_GRANITE,          "granite"          },
-    { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
-    { LLM_ARCH_CHAMELEON,        "chameleon"        },
-    { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
-    { LLM_ARCH_UNKNOWN,          "(unknown)"        },
-};
-
-enum llm_kv {
-    LLM_KV_GENERAL_TYPE,
-    LLM_KV_GENERAL_ARCHITECTURE,
-    LLM_KV_GENERAL_QUANTIZATION_VERSION,
-    LLM_KV_GENERAL_ALIGNMENT,
-    LLM_KV_GENERAL_NAME,
-    LLM_KV_GENERAL_AUTHOR,
-    LLM_KV_GENERAL_VERSION,
-    LLM_KV_GENERAL_URL,
-    LLM_KV_GENERAL_DESCRIPTION,
-    LLM_KV_GENERAL_LICENSE,
-    LLM_KV_GENERAL_SOURCE_URL,
-    LLM_KV_GENERAL_SOURCE_HF_REPO,
-
-    LLM_KV_VOCAB_SIZE,
-    LLM_KV_CONTEXT_LENGTH,
-    LLM_KV_EMBEDDING_LENGTH,
-    LLM_KV_FEATURES_LENGTH,
-    LLM_KV_BLOCK_COUNT,
-    LLM_KV_LEADING_DENSE_BLOCK_COUNT,
-    LLM_KV_FEED_FORWARD_LENGTH,
-    LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
-    LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
-    LLM_KV_USE_PARALLEL_RESIDUAL,
-    LLM_KV_TENSOR_DATA_LAYOUT,
-    LLM_KV_EXPERT_COUNT,
-    LLM_KV_EXPERT_USED_COUNT,
-    LLM_KV_EXPERT_SHARED_COUNT,
-    LLM_KV_EXPERT_WEIGHTS_SCALE,
-    LLM_KV_POOLING_TYPE,
-    LLM_KV_LOGIT_SCALE,
-    LLM_KV_DECODER_START_TOKEN_ID,
-    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
-    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
-    LLM_KV_SWIN_NORM,
-    LLM_KV_RESCALE_EVERY_N_LAYERS,
-    LLM_KV_TIME_MIX_EXTRA_DIM,
-    LLM_KV_TIME_DECAY_EXTRA_DIM,
-    LLM_KV_RESIDUAL_SCALE,
-    LLM_KV_EMBEDDING_SCALE,
-
-    LLM_KV_ATTENTION_HEAD_COUNT,
-    LLM_KV_ATTENTION_HEAD_COUNT_KV,
-    LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
-    LLM_KV_ATTENTION_CLAMP_KQV,
-    LLM_KV_ATTENTION_KEY_LENGTH,
-    LLM_KV_ATTENTION_VALUE_LENGTH,
-    LLM_KV_ATTENTION_LAYERNORM_EPS,
-    LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
-    LLM_KV_ATTENTION_GROUPNORM_EPS,
-    LLM_KV_ATTENTION_GROUPNORM_GROUPS,
-    LLM_KV_ATTENTION_CAUSAL,
-    LLM_KV_ATTENTION_Q_LORA_RANK,
-    LLM_KV_ATTENTION_KV_LORA_RANK,
-    LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
-    LLM_KV_ATTENTION_SLIDING_WINDOW,
-    LLM_KV_ATTENTION_SCALE,
-
-    LLM_KV_ROPE_DIMENSION_COUNT,
-    LLM_KV_ROPE_DIMENSION_SECTIONS,
-    LLM_KV_ROPE_FREQ_BASE,
-    LLM_KV_ROPE_SCALE_LINEAR,
-    LLM_KV_ROPE_SCALING_TYPE,
-    LLM_KV_ROPE_SCALING_FACTOR,
-    LLM_KV_ROPE_SCALING_ATTN_FACTOR,
-    LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
-    LLM_KV_ROPE_SCALING_FINETUNED,
-    LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
-
-    LLM_KV_SPLIT_NO,
-    LLM_KV_SPLIT_COUNT,
-    LLM_KV_SPLIT_TENSORS_COUNT,
-
-    LLM_KV_SSM_INNER_SIZE,
-    LLM_KV_SSM_CONV_KERNEL,
-    LLM_KV_SSM_STATE_SIZE,
-    LLM_KV_SSM_TIME_STEP_RANK,
-    LLM_KV_SSM_DT_B_C_RMS,
-
-    LLM_KV_WKV_HEAD_SIZE,
-
-    LLM_KV_TOKENIZER_MODEL,
-    LLM_KV_TOKENIZER_PRE,
-    LLM_KV_TOKENIZER_LIST,
-    LLM_KV_TOKENIZER_TOKEN_TYPE,
-    LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
-    LLM_KV_TOKENIZER_SCORES,
-    LLM_KV_TOKENIZER_MERGES,
-    LLM_KV_TOKENIZER_BOS_ID,
-    LLM_KV_TOKENIZER_EOS_ID,
-    LLM_KV_TOKENIZER_EOT_ID,
-    LLM_KV_TOKENIZER_EOM_ID,
-    LLM_KV_TOKENIZER_UNK_ID,
-    LLM_KV_TOKENIZER_SEP_ID,
-    LLM_KV_TOKENIZER_PAD_ID,
-    LLM_KV_TOKENIZER_CLS_ID,
-    LLM_KV_TOKENIZER_MASK_ID,
-    LLM_KV_TOKENIZER_ADD_BOS,
-    LLM_KV_TOKENIZER_ADD_EOS,
-    LLM_KV_TOKENIZER_ADD_PREFIX,
-    LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
-    LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
-    LLM_KV_TOKENIZER_HF_JSON,
-    LLM_KV_TOKENIZER_RWKV,
-    LLM_KV_TOKENIZER_FIM_PRE_ID,
-    LLM_KV_TOKENIZER_FIM_SUF_ID,
-    LLM_KV_TOKENIZER_FIM_MID_ID,
-    LLM_KV_TOKENIZER_FIM_PAD_ID,
-    LLM_KV_TOKENIZER_FIM_REP_ID,
-    LLM_KV_TOKENIZER_FIM_SEP_ID,
-
-    LLM_KV_ADAPTER_TYPE,
-    LLM_KV_ADAPTER_LORA_ALPHA,
-
-    LLM_KV_POSNET_EMBEDDING_LENGTH,
-    LLM_KV_POSNET_BLOCK_COUNT,
-
-    LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
-    LLM_KV_CONVNEXT_BLOCK_COUNT,
-
-    // deprecated:
-    LLM_KV_TOKENIZER_PREFIX_ID,
-    LLM_KV_TOKENIZER_SUFFIX_ID,
-    LLM_KV_TOKENIZER_MIDDLE_ID,
-};
-
-static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
-    { LLM_KV_GENERAL_TYPE,                  "general.type"                          },
-    { LLM_KV_GENERAL_ARCHITECTURE,          "general.architecture"                  },
-    { LLM_KV_GENERAL_QUANTIZATION_VERSION,  "general.quantization_version"          },
-    { LLM_KV_GENERAL_ALIGNMENT,             "general.alignment"                     },
-    { LLM_KV_GENERAL_NAME,                  "general.name"                          },
-    { LLM_KV_GENERAL_AUTHOR,                "general.author"                        },
-    { LLM_KV_GENERAL_VERSION,               "general.version"                       },
-    { LLM_KV_GENERAL_URL,                   "general.url"                           },
-    { LLM_KV_GENERAL_DESCRIPTION,           "general.description"                   },
-    { LLM_KV_GENERAL_LICENSE,               "general.license"                       },
-    { LLM_KV_GENERAL_SOURCE_URL,            "general.source.url"                    },
-    { LLM_KV_GENERAL_SOURCE_HF_REPO,        "general.source.huggingface.repository" },
-
-    { LLM_KV_VOCAB_SIZE,                        "%s.vocab_size"                        },
-    { LLM_KV_CONTEXT_LENGTH,                    "%s.context_length"                    },
-    { LLM_KV_EMBEDDING_LENGTH,                  "%s.embedding_length"                  },
-    { LLM_KV_FEATURES_LENGTH,                   "%s.features_length"                   },
-    { LLM_KV_BLOCK_COUNT,                       "%s.block_count"                       },
-    { LLM_KV_LEADING_DENSE_BLOCK_COUNT,         "%s.leading_dense_block_count"         },
-    { LLM_KV_FEED_FORWARD_LENGTH,               "%s.feed_forward_length"               },
-    { LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        "%s.expert_feed_forward_length"        },
-    { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
-    { LLM_KV_USE_PARALLEL_RESIDUAL,             "%s.use_parallel_residual"             },
-    { LLM_KV_TENSOR_DATA_LAYOUT,                "%s.tensor_data_layout"                },
-    { LLM_KV_EXPERT_COUNT,                      "%s.expert_count"                      },
-    { LLM_KV_EXPERT_USED_COUNT,                 "%s.expert_used_count"                 },
-    { LLM_KV_EXPERT_SHARED_COUNT,               "%s.expert_shared_count"               },
-    { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              },
-    { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
-    { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
-    { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
-    { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"            },
-    { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"           },
-    { LLM_KV_SWIN_NORM,                         "%s.swin_norm"                         },
-    { LLM_KV_RESCALE_EVERY_N_LAYERS,            "%s.rescale_every_n_layers"            },
-    { LLM_KV_TIME_MIX_EXTRA_DIM,                "%s.time_mix_extra_dim"                },
-    { LLM_KV_TIME_DECAY_EXTRA_DIM,              "%s.time_decay_extra_dim"              },
-    { LLM_KV_RESIDUAL_SCALE,                    "%s.residual_scale"                    },
-    { LLM_KV_EMBEDDING_SCALE,                   "%s.embedding_scale"                   },
-
-    { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
-    { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
-    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,         "%s.attention.max_alibi_bias"         },
-    { LLM_KV_ATTENTION_CLAMP_KQV,              "%s.attention.clamp_kqv"              },
-    { LLM_KV_ATTENTION_KEY_LENGTH,             "%s.attention.key_length"             },
-    { LLM_KV_ATTENTION_VALUE_LENGTH,           "%s.attention.value_length"           },
-    { LLM_KV_ATTENTION_LAYERNORM_EPS,          "%s.attention.layer_norm_epsilon"     },
-    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,      "%s.attention.layer_norm_rms_epsilon" },
-    { LLM_KV_ATTENTION_GROUPNORM_EPS,          "%s.attention.group_norm_epsilon"     },
-    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,       "%s.attention.group_norm_groups"      },
-    { LLM_KV_ATTENTION_CAUSAL,                 "%s.attention.causal"                 },
-    { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
-    { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
-    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
-    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
-    { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
-
-    { LLM_KV_ROPE_DIMENSION_COUNT,             "%s.rope.dimension_count"                 },
-    { LLM_KV_ROPE_DIMENSION_SECTIONS,          "%s.rope.dimension_sections"              },
-    { LLM_KV_ROPE_FREQ_BASE,                   "%s.rope.freq_base"                       },
-    { LLM_KV_ROPE_SCALE_LINEAR,                "%s.rope.scale_linear"                    },
-    { LLM_KV_ROPE_SCALING_TYPE,                "%s.rope.scaling.type"                    },
-    { LLM_KV_ROPE_SCALING_FACTOR,              "%s.rope.scaling.factor"                  },
-    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,         "%s.rope.scaling.attn_factor"             },
-    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,        "%s.rope.scaling.original_context_length" },
-    { LLM_KV_ROPE_SCALING_FINETUNED,           "%s.rope.scaling.finetuned"               },
-    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL,        "%s.rope.scaling.yarn_log_multiplier"     },
-
-    { LLM_KV_SPLIT_NO,                         "split.no"            },
-    { LLM_KV_SPLIT_COUNT,                      "split.count"         },
-    { LLM_KV_SPLIT_TENSORS_COUNT,              "split.tensors.count" },
-
-    { LLM_KV_SSM_CONV_KERNEL,                  "%s.ssm.conv_kernel"    },
-    { LLM_KV_SSM_INNER_SIZE,                   "%s.ssm.inner_size"     },
-    { LLM_KV_SSM_STATE_SIZE,                   "%s.ssm.state_size"     },
-    { LLM_KV_SSM_TIME_STEP_RANK,               "%s.ssm.time_step_rank" },
-    { LLM_KV_SSM_DT_B_C_RMS,                   "%s.ssm.dt_b_c_rms"     },
-
-    { LLM_KV_WKV_HEAD_SIZE,                    "%s.wkv.head_size" },
-
-    { LLM_KV_POSNET_EMBEDDING_LENGTH,          "%s.posnet.embedding_length" },
-    { LLM_KV_POSNET_BLOCK_COUNT,               "%s.posnet.block_count"      },
-
-    { LLM_KV_CONVNEXT_EMBEDDING_LENGTH,        "%s.convnext.embedding_length" },
-    { LLM_KV_CONVNEXT_BLOCK_COUNT,             "%s.convnext.block_count"      },
-
-    { LLM_KV_TOKENIZER_MODEL,                  "tokenizer.ggml.model"                    },
-    { LLM_KV_TOKENIZER_PRE,                    "tokenizer.ggml.pre"                      },
-    { LLM_KV_TOKENIZER_LIST,                   "tokenizer.ggml.tokens"                   },
-    { LLM_KV_TOKENIZER_TOKEN_TYPE,             "tokenizer.ggml.token_type"               },
-    { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,       "tokenizer.ggml.token_type_count"         },
-    { LLM_KV_TOKENIZER_SCORES,                 "tokenizer.ggml.scores"                   },
-    { LLM_KV_TOKENIZER_MERGES,                 "tokenizer.ggml.merges"                   },
-    { LLM_KV_TOKENIZER_BOS_ID,                 "tokenizer.ggml.bos_token_id"             },
-    { LLM_KV_TOKENIZER_EOS_ID,                 "tokenizer.ggml.eos_token_id"             },
-    { LLM_KV_TOKENIZER_EOT_ID,                 "tokenizer.ggml.eot_token_id"             },
-    { LLM_KV_TOKENIZER_EOM_ID,                 "tokenizer.ggml.eom_token_id"             },
-    { LLM_KV_TOKENIZER_UNK_ID,                 "tokenizer.ggml.unknown_token_id"         },
-    { LLM_KV_TOKENIZER_SEP_ID,                 "tokenizer.ggml.seperator_token_id"       },
-    { LLM_KV_TOKENIZER_PAD_ID,                 "tokenizer.ggml.padding_token_id"         },
-    { LLM_KV_TOKENIZER_CLS_ID,                 "tokenizer.ggml.cls_token_id"             },
-    { LLM_KV_TOKENIZER_MASK_ID,                "tokenizer.ggml.mask_token_id"            },
-    { LLM_KV_TOKENIZER_ADD_BOS,                "tokenizer.ggml.add_bos_token"            },
-    { LLM_KV_TOKENIZER_ADD_EOS,                "tokenizer.ggml.add_eos_token"            },
-    { LLM_KV_TOKENIZER_ADD_PREFIX,             "tokenizer.ggml.add_space_prefix"         },
-    { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,        "tokenizer.ggml.remove_extra_whitespaces" },
-    { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,   "tokenizer.ggml.precompiled_charsmap"     },
-    { LLM_KV_TOKENIZER_HF_JSON,                "tokenizer.huggingface.json"              },
-    { LLM_KV_TOKENIZER_RWKV,                   "tokenizer.rwkv.world"                    },
-    { LLM_KV_TOKENIZER_FIM_PRE_ID,             "tokenizer.ggml.fim_pre_token_id"         },
-    { LLM_KV_TOKENIZER_FIM_SUF_ID,             "tokenizer.ggml.fim_suf_token_id"         },
-    { LLM_KV_TOKENIZER_FIM_MID_ID,             "tokenizer.ggml.fim_mid_token_id"         },
-    { LLM_KV_TOKENIZER_FIM_PAD_ID,             "tokenizer.ggml.fim_pad_token_id"         },
-    { LLM_KV_TOKENIZER_FIM_REP_ID,             "tokenizer.ggml.fim_rep_token_id"         },
-    { LLM_KV_TOKENIZER_FIM_SEP_ID,             "tokenizer.ggml.fim_sep_token_id"         },
-
-    { LLM_KV_ADAPTER_TYPE,                     "adapter.type"       },
-    { LLM_KV_ADAPTER_LORA_ALPHA,               "adapter.lora.alpha" },
-
-    // deprecated
-    { LLM_KV_TOKENIZER_PREFIX_ID,              "tokenizer.ggml.prefix_token_id" },
-    { LLM_KV_TOKENIZER_SUFFIX_ID,              "tokenizer.ggml.suffix_token_id" },
-    { LLM_KV_TOKENIZER_MIDDLE_ID,              "tokenizer.ggml.middle_token_id" },
-};
-
-struct LLM_KV {
-    LLM_KV(llm_arch arch) : arch(arch) {}
-
-    llm_arch arch;
-
-    std::string operator()(llm_kv kv) const {
-        return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
-    }
-};
-
-enum llm_tensor {
-    LLM_TENSOR_TOKEN_EMBD,
-    LLM_TENSOR_TOKEN_EMBD_NORM,
-    LLM_TENSOR_TOKEN_TYPES,
-    LLM_TENSOR_POS_EMBD,
-    LLM_TENSOR_OUTPUT,
-    LLM_TENSOR_OUTPUT_NORM,
-    LLM_TENSOR_ROPE_FREQS,
-    LLM_TENSOR_ROPE_FACTORS_LONG,
-    LLM_TENSOR_ROPE_FACTORS_SHORT,
-    LLM_TENSOR_ATTN_Q,
-    LLM_TENSOR_ATTN_K,
-    LLM_TENSOR_ATTN_V,
-    LLM_TENSOR_ATTN_QKV,
-    LLM_TENSOR_ATTN_OUT,
-    LLM_TENSOR_ATTN_NORM,
-    LLM_TENSOR_ATTN_NORM_2,
-    LLM_TENSOR_ATTN_OUT_NORM,
-    LLM_TENSOR_ATTN_POST_NORM,
-    LLM_TENSOR_ATTN_ROT_EMBD,
-    LLM_TENSOR_FFN_GATE_INP,
-    LLM_TENSOR_FFN_GATE_INP_SHEXP,
-    LLM_TENSOR_FFN_NORM,
-    LLM_TENSOR_FFN_POST_NORM,
-    LLM_TENSOR_FFN_GATE,
-    LLM_TENSOR_FFN_DOWN,
-    LLM_TENSOR_FFN_UP,
-    LLM_TENSOR_FFN_ACT,
-    LLM_TENSOR_FFN_DOWN_EXP,  // split experts for backward compatibility
-    LLM_TENSOR_FFN_GATE_EXP,
-    LLM_TENSOR_FFN_UP_EXP,
-    LLM_TENSOR_FFN_NORM_EXPS,
-    LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
-    LLM_TENSOR_FFN_GATE_EXPS,
-    LLM_TENSOR_FFN_UP_EXPS,
-    LLM_TENSOR_FFN_DOWN_SHEXP,
-    LLM_TENSOR_FFN_GATE_SHEXP,
-    LLM_TENSOR_FFN_UP_SHEXP,
-    LLM_TENSOR_ATTN_Q_NORM,
-    LLM_TENSOR_ATTN_K_NORM,
-    LLM_TENSOR_LAYER_OUT_NORM,
-    LLM_TENSOR_SSM_IN,
-    LLM_TENSOR_SSM_CONV1D,
-    LLM_TENSOR_SSM_X,
-    LLM_TENSOR_SSM_DT,
-    LLM_TENSOR_SSM_A,
-    LLM_TENSOR_SSM_D,
-    LLM_TENSOR_SSM_OUT,
-    LLM_TENSOR_TIME_MIX_W1,
-    LLM_TENSOR_TIME_MIX_W2,
-    LLM_TENSOR_TIME_MIX_LERP_X,
-    LLM_TENSOR_TIME_MIX_LERP_W,
-    LLM_TENSOR_TIME_MIX_LERP_K,
-    LLM_TENSOR_TIME_MIX_LERP_V,
-    LLM_TENSOR_TIME_MIX_LERP_R,
-    LLM_TENSOR_TIME_MIX_LERP_G,
-    LLM_TENSOR_TIME_MIX_FIRST,
-    LLM_TENSOR_TIME_MIX_DECAY,
-    LLM_TENSOR_TIME_MIX_DECAY_W1,
-    LLM_TENSOR_TIME_MIX_DECAY_W2,
-    LLM_TENSOR_TIME_MIX_KEY,
-    LLM_TENSOR_TIME_MIX_VALUE,
-    LLM_TENSOR_TIME_MIX_RECEPTANCE,
-    LLM_TENSOR_TIME_MIX_GATE,
-    LLM_TENSOR_TIME_MIX_LN,
-    LLM_TENSOR_TIME_MIX_OUTPUT,
-    LLM_TENSOR_CHANNEL_MIX_LERP_K,
-    LLM_TENSOR_CHANNEL_MIX_LERP_R,
-    LLM_TENSOR_CHANNEL_MIX_KEY,
-    LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,
-    LLM_TENSOR_CHANNEL_MIX_VALUE,
-    LLM_TENSOR_ATTN_Q_A,
-    LLM_TENSOR_ATTN_Q_B,
-    LLM_TENSOR_ATTN_KV_A_MQA,
-    LLM_TENSOR_ATTN_KV_B,
-    LLM_TENSOR_ATTN_Q_A_NORM,
-    LLM_TENSOR_ATTN_KV_A_NORM,
-    LLM_TENSOR_ATTN_SUB_NORM,
-    LLM_TENSOR_FFN_SUB_NORM,
-    LLM_TENSOR_DEC_ATTN_NORM,
-    LLM_TENSOR_DEC_ATTN_Q,
-    LLM_TENSOR_DEC_ATTN_K,
-    LLM_TENSOR_DEC_ATTN_V,
-    LLM_TENSOR_DEC_ATTN_OUT,
-    LLM_TENSOR_DEC_ATTN_REL_B,
-    LLM_TENSOR_DEC_CROSS_ATTN_NORM,
-    LLM_TENSOR_DEC_CROSS_ATTN_Q,
-    LLM_TENSOR_DEC_CROSS_ATTN_K,
-    LLM_TENSOR_DEC_CROSS_ATTN_V,
-    LLM_TENSOR_DEC_CROSS_ATTN_OUT,
-    LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
-    LLM_TENSOR_DEC_FFN_NORM,
-    LLM_TENSOR_DEC_FFN_GATE,
-    LLM_TENSOR_DEC_FFN_DOWN,
-    LLM_TENSOR_DEC_FFN_UP,
-    LLM_TENSOR_DEC_OUTPUT_NORM,
-    LLM_TENSOR_ENC_ATTN_NORM,
-    LLM_TENSOR_ENC_ATTN_Q,
-    LLM_TENSOR_ENC_ATTN_K,
-    LLM_TENSOR_ENC_ATTN_V,
-    LLM_TENSOR_ENC_ATTN_OUT,
-    LLM_TENSOR_ENC_ATTN_REL_B,
-    LLM_TENSOR_ENC_FFN_NORM,
-    LLM_TENSOR_ENC_FFN_GATE,
-    LLM_TENSOR_ENC_FFN_DOWN,
-    LLM_TENSOR_ENC_FFN_UP,
-    LLM_TENSOR_ENC_OUTPUT_NORM,
-    LLM_TENSOR_CLS,
-    LLM_TENSOR_CLS_OUT,
-    LLM_TENSOR_CONV1D,
-    LLM_TENSOR_CONVNEXT_DW,
-    LLM_TENSOR_CONVNEXT_NORM,
-    LLM_TENSOR_CONVNEXT_PW1,
-    LLM_TENSOR_CONVNEXT_PW2,
-    LLM_TENSOR_CONVNEXT_GAMMA,
-    LLM_TENSOR_POS_NET_CONV1,
-    LLM_TENSOR_POS_NET_CONV2,
-    LLM_TENSOR_POS_NET_NORM,
-    LLM_TENSOR_POS_NET_NORM1,
-    LLM_TENSOR_POS_NET_NORM2,
-    LLM_TENSOR_POS_NET_ATTN_NORM,
-    LLM_TENSOR_POS_NET_ATTN_Q,
-    LLM_TENSOR_POS_NET_ATTN_K,
-    LLM_TENSOR_POS_NET_ATTN_V,
-    LLM_TENSOR_POS_NET_ATTN_OUT,
-};
-
-static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
-    {
-        LLM_ARCH_LLAMA,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
-            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
-            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_DECI,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
-            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
-            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_BAICHUAN,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_FALCON,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_NORM_2,     "blk.%d.attn_norm_2" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_GROK,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
-            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
-            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-        },
-    },
-    {
-        LLM_ARCH_GPT2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_POS_EMBD,        "position_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_GPTJ,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-        },
-    },
-    {
-        LLM_ARCH_GPTNEOX,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_MPT,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output"},
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_ACT,         "blk.%d.ffn.act" },
-            { LLM_TENSOR_POS_EMBD,        "position_embd" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm"},
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm"},
-        },
-    },
-    {
-        LLM_ARCH_STARCODER,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_POS_EMBD,        "position_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_REFACT,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_BERT,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
-            { LLM_TENSOR_POS_EMBD,        "position_embd" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_CLS,             "cls" },
-            { LLM_TENSOR_CLS_OUT,         "cls.output" },
-        },
-    },
-    {
-        LLM_ARCH_NOMIC_BERT,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_JINA_BERT_V2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
-            { LLM_TENSOR_ATTN_NORM_2,     "blk.%d.attn_norm_2" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_CLS,             "cls" },
-        },
-    },
-    {
-        LLM_ARCH_BLOOM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_STABLELM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-        },
-    },
-    {
-        LLM_ARCH_QWEN,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_QWEN2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_QWEN2VL,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_QWEN2MOE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_OUTPUT,             "output" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
-            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
-            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
-            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
-        },
-    },
-    {
-        LLM_ARCH_PHI2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_PHI3,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_OUTPUT,             "output" },
-            { LLM_TENSOR_ROPE_FACTORS_LONG,  "rope_factors_long" },
-            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,           "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_PLAMO,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_CODESHELL,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_ORION,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_INTERNLM2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_MINICPM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ROPE_FACTORS_LONG,  "rope_factors_long" },
-            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
-            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
-            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
-        },
-    },
-    {
-        LLM_ARCH_MINICPM3,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_OUTPUT,             "output" },
-            { LLM_TENSOR_ROPE_FACTORS_LONG,  "rope_factors_long" },
-            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q_A_NORM,      "blk.%d.attn_q_a_norm" },
-            { LLM_TENSOR_ATTN_KV_A_NORM,     "blk.%d.attn_kv_a_norm" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_A,           "blk.%d.attn_q_a" },
-            { LLM_TENSOR_ATTN_Q_B,           "blk.%d.attn_q_b" },
-            { LLM_TENSOR_ATTN_KV_A_MQA,      "blk.%d.attn_kv_a_mqa" },
-            { LLM_TENSOR_ATTN_KV_B,          "blk.%d.attn_kv_b" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_GEMMA,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_GEMMA2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
-        },
-    },
-    {
-        LLM_ARCH_STARCODER2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_MAMBA,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_SSM_IN,          "blk.%d.ssm_in" },
-            { LLM_TENSOR_SSM_CONV1D,      "blk.%d.ssm_conv1d" },
-            { LLM_TENSOR_SSM_X,           "blk.%d.ssm_x" },
-            { LLM_TENSOR_SSM_DT,          "blk.%d.ssm_dt" },
-            { LLM_TENSOR_SSM_A,           "blk.%d.ssm_a" },
-            { LLM_TENSOR_SSM_D,           "blk.%d.ssm_d" },
-            { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
-        },
-    },
-    {
-        LLM_ARCH_XVERSE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_COMMAND_R,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-        },
-    },
-    {
-        LLM_ARCH_DBRX,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_OLMO,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_OLMO2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_OLMOE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_OUTPUT,             "output" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_OPENELM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_ARCTIC,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_NORM_EXPS,   "blk.%d.ffn_norm_exps" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_DEEPSEEK,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_OUTPUT,             "output" },
-            { LLM_TENSOR_ROPE_FREQS,         "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,      "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
-            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
-            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
-            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
-        },
-    },
-    {
-        LLM_ARCH_DEEPSEEK2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_OUTPUT,             "output" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q_A_NORM,      "blk.%d.attn_q_a_norm" },
-            { LLM_TENSOR_ATTN_KV_A_NORM,     "blk.%d.attn_kv_a_norm" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_A,           "blk.%d.attn_q_a" },
-            { LLM_TENSOR_ATTN_Q_B,           "blk.%d.attn_q_b" },
-            { LLM_TENSOR_ATTN_KV_A_MQA,      "blk.%d.attn_kv_a_mqa" },
-            { LLM_TENSOR_ATTN_KV_B,          "blk.%d.attn_kv_b" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
-            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
-            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
-            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
-        },
-    },
-    {
-        LLM_ARCH_CHATGLM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_BITNET,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_SUB_NORM,      "blk.%d.attn_sub_norm" },
-            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_SUB_NORM,       "blk.%d.ffn_sub_norm" },
-        },
-    },
-    {
-        LLM_ARCH_T5,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
-            { LLM_TENSOR_OUTPUT,               "output" },
-            { LLM_TENSOR_DEC_OUTPUT_NORM,      "dec.output_norm" },
-            { LLM_TENSOR_DEC_ATTN_NORM,        "dec.blk.%d.attn_norm" },
-            { LLM_TENSOR_DEC_ATTN_Q,           "dec.blk.%d.attn_q" },
-            { LLM_TENSOR_DEC_ATTN_K,           "dec.blk.%d.attn_k" },
-            { LLM_TENSOR_DEC_ATTN_V,           "dec.blk.%d.attn_v" },
-            { LLM_TENSOR_DEC_ATTN_OUT,         "dec.blk.%d.attn_o" },
-            { LLM_TENSOR_DEC_ATTN_REL_B,       "dec.blk.%d.attn_rel_b" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "dec.blk.%d.cross_attn_norm" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_Q,     "dec.blk.%d.cross_attn_q" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_K,     "dec.blk.%d.cross_attn_k" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_V,     "dec.blk.%d.cross_attn_v" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_OUT,   "dec.blk.%d.cross_attn_o" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" },
-            { LLM_TENSOR_DEC_FFN_NORM,         "dec.blk.%d.ffn_norm" },
-            { LLM_TENSOR_DEC_FFN_GATE,         "dec.blk.%d.ffn_gate" },
-            { LLM_TENSOR_DEC_FFN_DOWN,         "dec.blk.%d.ffn_down" },
-            { LLM_TENSOR_DEC_FFN_UP,           "dec.blk.%d.ffn_up" },
-            { LLM_TENSOR_ENC_OUTPUT_NORM,      "enc.output_norm" },
-            { LLM_TENSOR_ENC_ATTN_NORM,        "enc.blk.%d.attn_norm" },
-            { LLM_TENSOR_ENC_ATTN_Q,           "enc.blk.%d.attn_q" },
-            { LLM_TENSOR_ENC_ATTN_K,           "enc.blk.%d.attn_k" },
-            { LLM_TENSOR_ENC_ATTN_V,           "enc.blk.%d.attn_v" },
-            { LLM_TENSOR_ENC_ATTN_OUT,         "enc.blk.%d.attn_o" },
-            { LLM_TENSOR_ENC_ATTN_REL_B,       "enc.blk.%d.attn_rel_b" },
-            { LLM_TENSOR_ENC_FFN_NORM,         "enc.blk.%d.ffn_norm" },
-            { LLM_TENSOR_ENC_FFN_GATE,         "enc.blk.%d.ffn_gate" },
-            { LLM_TENSOR_ENC_FFN_DOWN,         "enc.blk.%d.ffn_down" },
-            { LLM_TENSOR_ENC_FFN_UP,           "enc.blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_T5ENCODER,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
-            { LLM_TENSOR_OUTPUT,               "output" },
-            { LLM_TENSOR_ENC_OUTPUT_NORM,      "enc.output_norm" },
-            { LLM_TENSOR_ENC_ATTN_NORM,        "enc.blk.%d.attn_norm" },
-            { LLM_TENSOR_ENC_ATTN_Q,           "enc.blk.%d.attn_q" },
-            { LLM_TENSOR_ENC_ATTN_K,           "enc.blk.%d.attn_k" },
-            { LLM_TENSOR_ENC_ATTN_V,           "enc.blk.%d.attn_v" },
-            { LLM_TENSOR_ENC_ATTN_OUT,         "enc.blk.%d.attn_o" },
-            { LLM_TENSOR_ENC_ATTN_REL_B,       "enc.blk.%d.attn_rel_b" },
-            { LLM_TENSOR_ENC_FFN_NORM,         "enc.blk.%d.ffn_norm" },
-            { LLM_TENSOR_ENC_FFN_GATE,         "enc.blk.%d.ffn_gate" },
-            { LLM_TENSOR_ENC_FFN_DOWN,         "enc.blk.%d.ffn_down" },
-            { LLM_TENSOR_ENC_FFN_UP,           "enc.blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_JAIS,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_NEMOTRON,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_EXAONE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_RWKV6,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,                "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM,           "token_embd_norm" },
-            { LLM_TENSOR_OUTPUT_NORM,               "output_norm" },
-            { LLM_TENSOR_OUTPUT,                    "output" },
-            { LLM_TENSOR_ATTN_NORM,                 "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_NORM_2,               "blk.%d.attn_norm_2" },
-            { LLM_TENSOR_TIME_MIX_W1,               "blk.%d.time_mix_w1" },
-            { LLM_TENSOR_TIME_MIX_W2,               "blk.%d.time_mix_w2" },
-            { LLM_TENSOR_TIME_MIX_LERP_X,           "blk.%d.time_mix_lerp_x" },
-            { LLM_TENSOR_TIME_MIX_LERP_W,           "blk.%d.time_mix_lerp_w" },
-            { LLM_TENSOR_TIME_MIX_LERP_K,           "blk.%d.time_mix_lerp_k" },
-            { LLM_TENSOR_TIME_MIX_LERP_V,           "blk.%d.time_mix_lerp_v" },
-            { LLM_TENSOR_TIME_MIX_LERP_R,           "blk.%d.time_mix_lerp_r" },
-            { LLM_TENSOR_TIME_MIX_LERP_G,           "blk.%d.time_mix_lerp_g" },
-            { LLM_TENSOR_TIME_MIX_FIRST,            "blk.%d.time_mix_first" },
-            { LLM_TENSOR_TIME_MIX_DECAY,            "blk.%d.time_mix_decay" },
-            { LLM_TENSOR_TIME_MIX_DECAY_W1,         "blk.%d.time_mix_decay_w1" },
-            { LLM_TENSOR_TIME_MIX_DECAY_W2,         "blk.%d.time_mix_decay_w2" },
-            { LLM_TENSOR_TIME_MIX_KEY,              "blk.%d.time_mix_key" },
-            { LLM_TENSOR_TIME_MIX_VALUE,            "blk.%d.time_mix_value" },
-            { LLM_TENSOR_TIME_MIX_RECEPTANCE,       "blk.%d.time_mix_receptance" },
-            { LLM_TENSOR_TIME_MIX_GATE,             "blk.%d.time_mix_gate" },
-            { LLM_TENSOR_TIME_MIX_LN,               "blk.%d.time_mix_ln" },
-            { LLM_TENSOR_TIME_MIX_OUTPUT,           "blk.%d.time_mix_output" },
-            { LLM_TENSOR_CHANNEL_MIX_LERP_K,        "blk.%d.channel_mix_lerp_k" },
-            { LLM_TENSOR_CHANNEL_MIX_LERP_R,        "blk.%d.channel_mix_lerp_r" },
-            { LLM_TENSOR_CHANNEL_MIX_KEY,           "blk.%d.channel_mix_key" },
-            { LLM_TENSOR_CHANNEL_MIX_VALUE,         "blk.%d.channel_mix_value" },
-            { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,    "blk.%d.channel_mix_receptance" },
-        },
-    },
-    {
-        LLM_ARCH_GRANITE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_GRANITE_MOE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_CHAMELEON,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-        },
-    },
-    {
-        LLM_ARCH_WAVTOKENIZER_DEC,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,        "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM,   "token_embd_norm" },
-            { LLM_TENSOR_CONV1D,            "conv1d" },
-            { LLM_TENSOR_CONVNEXT_DW,       "convnext.%d.dw" },
-            { LLM_TENSOR_CONVNEXT_NORM,     "convnext.%d.norm" },
-            { LLM_TENSOR_CONVNEXT_PW1,      "convnext.%d.pw1" },
-            { LLM_TENSOR_CONVNEXT_PW2,      "convnext.%d.pw2" },
-            { LLM_TENSOR_CONVNEXT_GAMMA,    "convnext.%d.gamma" },
-            { LLM_TENSOR_OUTPUT_NORM,       "output_norm" },
-            { LLM_TENSOR_OUTPUT,            "output" },
-            { LLM_TENSOR_POS_NET_CONV1,     "posnet.%d.conv1" },
-            { LLM_TENSOR_POS_NET_CONV2,     "posnet.%d.conv2" },
-            { LLM_TENSOR_POS_NET_NORM,      "posnet.%d.norm" },
-            { LLM_TENSOR_POS_NET_NORM1,     "posnet.%d.norm1" },
-            { LLM_TENSOR_POS_NET_NORM2,     "posnet.%d.norm2" },
-            { LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" },
-            { LLM_TENSOR_POS_NET_ATTN_Q,    "posnet.%d.attn_q" },
-            { LLM_TENSOR_POS_NET_ATTN_K,    "posnet.%d.attn_k" },
-            { LLM_TENSOR_POS_NET_ATTN_V,    "posnet.%d.attn_v" },
-            { LLM_TENSOR_POS_NET_ATTN_OUT,  "posnet.%d.attn_output" },
-        },
-    },
-    {
-        LLM_ARCH_UNKNOWN,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-        },
-    },
-};
-
-enum llm_chat_template {
-    LLM_CHAT_TEMPLATE_CHATML,
-    LLM_CHAT_TEMPLATE_LLAMA_2,
-    LLM_CHAT_TEMPLATE_LLAMA_2_SYS,
-    LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS,
-    LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP,
-    LLM_CHAT_TEMPLATE_MISTRAL_V1,
-    LLM_CHAT_TEMPLATE_MISTRAL_V3,
-    LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
-    LLM_CHAT_TEMPLATE_MISTRAL_V7,
-    LLM_CHAT_TEMPLATE_PHI_3,
-    LLM_CHAT_TEMPLATE_FALCON_3,
-    LLM_CHAT_TEMPLATE_ZEPHYR,
-    LLM_CHAT_TEMPLATE_MONARCH,
-    LLM_CHAT_TEMPLATE_GEMMA,
-    LLM_CHAT_TEMPLATE_ORION,
-    LLM_CHAT_TEMPLATE_OPENCHAT,
-    LLM_CHAT_TEMPLATE_VICUNA,
-    LLM_CHAT_TEMPLATE_VICUNA_ORCA,
-    LLM_CHAT_TEMPLATE_DEEPSEEK,
-    LLM_CHAT_TEMPLATE_DEEPSEEK_2,
-    LLM_CHAT_TEMPLATE_COMMAND_R,
-    LLM_CHAT_TEMPLATE_LLAMA_3,
-    LLM_CHAT_TEMPLATE_CHATGML_3,
-    LLM_CHAT_TEMPLATE_CHATGML_4,
-    LLM_CHAT_TEMPLATE_MINICPM,
-    LLM_CHAT_TEMPLATE_EXAONE_3,
-    LLM_CHAT_TEMPLATE_RWKV_WORLD,
-    LLM_CHAT_TEMPLATE_GRANITE,
-    LLM_CHAT_TEMPLATE_GIGACHAT,
-    LLM_CHAT_TEMPLATE_MEGREZ,
-    LLM_CHAT_TEMPLATE_UNKNOWN,
-};
-
-static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
-    { "chatml",            LLM_CHAT_TEMPLATE_CHATML            },
-    { "llama2",            LLM_CHAT_TEMPLATE_LLAMA_2           },
-    { "llama2-sys",        LLM_CHAT_TEMPLATE_LLAMA_2_SYS       },
-    { "llama2-sys-bos",    LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS   },
-    { "llama2-sys-strip",  LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP },
-    { "mistral-v1",        LLM_CHAT_TEMPLATE_MISTRAL_V1        },
-    { "mistral-v3",        LLM_CHAT_TEMPLATE_MISTRAL_V3        },
-    { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
-    { "mistral-v7",        LLM_CHAT_TEMPLATE_MISTRAL_V7        },
-    { "phi3",              LLM_CHAT_TEMPLATE_PHI_3             },
-    { "falcon3",           LLM_CHAT_TEMPLATE_FALCON_3          },
-    { "zephyr",            LLM_CHAT_TEMPLATE_ZEPHYR            },
-    { "monarch",           LLM_CHAT_TEMPLATE_MONARCH           },
-    { "gemma",             LLM_CHAT_TEMPLATE_GEMMA             },
-    { "orion",             LLM_CHAT_TEMPLATE_ORION             },
-    { "openchat",          LLM_CHAT_TEMPLATE_OPENCHAT          },
-    { "vicuna",            LLM_CHAT_TEMPLATE_VICUNA            },
-    { "vicuna-orca",       LLM_CHAT_TEMPLATE_VICUNA_ORCA       },
-    { "deepseek",          LLM_CHAT_TEMPLATE_DEEPSEEK          },
-    { "deepseek2",         LLM_CHAT_TEMPLATE_DEEPSEEK_2        },
-    { "command-r",         LLM_CHAT_TEMPLATE_COMMAND_R         },
-    { "llama3",            LLM_CHAT_TEMPLATE_LLAMA_3           },
-    { "chatglm3",          LLM_CHAT_TEMPLATE_CHATGML_3         },
-    { "chatglm4",          LLM_CHAT_TEMPLATE_CHATGML_4         },
-    { "minicpm",           LLM_CHAT_TEMPLATE_MINICPM           },
-    { "exaone3",           LLM_CHAT_TEMPLATE_EXAONE_3          },
-    { "rwkv-world",        LLM_CHAT_TEMPLATE_RWKV_WORLD        },
-    { "granite",           LLM_CHAT_TEMPLATE_GRANITE           },
-    { "gigachat",          LLM_CHAT_TEMPLATE_GIGACHAT          },
-    { "megrez",            LLM_CHAT_TEMPLATE_MEGREZ            },
-};
-
-static llm_arch llm_arch_from_string(const std::string & name) {
-    for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
-        if (kv.second == name) {
-            return kv.first;
-        }
-    }
-
-    return LLM_ARCH_UNKNOWN;
-}
-
-// helper to handle gguf constants
-// usage:
-//
-//   const auto tn = LLM_TN(LLM_ARCH_LLAMA);
-//
-//   std::string name = tn(LLM_TENSOR_OUTPUT);                     -> "output"
-//   std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias");         -> "token_embd.bias"
-//   std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3);     -> "blk.3.attn_norm.weight"
-//
-struct LLM_TN_IMPL {
-    const llm_arch arch;
-    const llm_tensor tensor;
-    const char * const suffix;
-    const int bid;
-    const int xid;
-
-    std::string str() const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-
-        std::string name = ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid);
-
-        if (suffix != nullptr) {
-            name += ".";
-            name += suffix;
-        }
-
-        return name;
-    }
-
-    operator std::string() const {
-        return str();
-    }
-
-    friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) {
-        return str == tn.str();
-    }
-
-    friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) {
-        return str != tn.str();
-    }
-};
-
-struct LLM_TN {
-    LLM_TN(llm_arch arch) : arch(arch) {}
-
-    llm_arch arch;
-
-    LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
-        return { arch, tensor, suffix, bid, xid };
-    }
-
-    LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const {
-        return { arch, tensor, nullptr, bid, xid };
-    }
-};
-
-//
-// gguf helpers
-//
-
-static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
-    { LLAMA_ROPE_SCALING_TYPE_NONE,       "none"       },
-    { LLAMA_ROPE_SCALING_TYPE_LINEAR,     "linear"     },
-    { LLAMA_ROPE_SCALING_TYPE_YARN,       "yarn"       },
-    { LLAMA_ROPE_SCALING_TYPE_LONGROPE,   "longrope"   },
-};
-
-static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
-    for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
-        if (kv.second == name) {
-            return (llama_rope_scaling_type) kv.first;
-        }
-    }
-
-    return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
-}
-
-static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
-    switch (type) {
-        case GGUF_TYPE_UINT8:   return std::to_string(((const uint8_t  *)data)[i]);
-        case GGUF_TYPE_INT8:    return std::to_string(((const int8_t   *)data)[i]);
-        case GGUF_TYPE_UINT16:  return std::to_string(((const uint16_t *)data)[i]);
-        case GGUF_TYPE_INT16:   return std::to_string(((const int16_t  *)data)[i]);
-        case GGUF_TYPE_UINT32:  return std::to_string(((const uint32_t *)data)[i]);
-        case GGUF_TYPE_INT32:   return std::to_string(((const int32_t  *)data)[i]);
-        case GGUF_TYPE_UINT64:  return std::to_string(((const uint64_t *)data)[i]);
-        case GGUF_TYPE_INT64:   return std::to_string(((const int64_t  *)data)[i]);
-        case GGUF_TYPE_FLOAT32: return std::to_string(((const float    *)data)[i]);
-        case GGUF_TYPE_FLOAT64: return std::to_string(((const double   *)data)[i]);
-        case GGUF_TYPE_BOOL:    return ((const bool *)data)[i] ? "true" : "false";
-        default:                return format("unknown type %d", type);
-    }
-}
-
-static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
-    const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
-
-    switch (type) {
-        case GGUF_TYPE_STRING:
-            return gguf_get_val_str(ctx_gguf, i);
-        case GGUF_TYPE_ARRAY:
-            {
-                const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
-                int arr_n = gguf_get_arr_n(ctx_gguf, i);
-                const void * data = gguf_get_arr_data(ctx_gguf, i);
-                std::stringstream ss;
-                ss << "[";
-                for (int j = 0; j < arr_n; j++) {
-                    if (arr_type == GGUF_TYPE_STRING) {
-                        std::string val = gguf_get_arr_str(ctx_gguf, i, j);
-                        // escape quotes
-                        replace_all(val, "\\", "\\\\");
-                        replace_all(val, "\"", "\\\"");
-                        ss << '"' << val << '"';
-                    } else if (arr_type == GGUF_TYPE_ARRAY) {
-                        ss << "???";
-                    } else {
-                        ss << gguf_data_to_str(arr_type, data, j);
-                    }
-                    if (j < arr_n - 1) {
-                        ss << ", ";
-                    }
-                }
-                ss << "]";
-                return ss.str();
-            }
-        default:
-            return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
-    }
-}
-
-//
-// llama helpers
-//
-
-#if defined(_WIN32)
-static std::string llama_format_win_err(DWORD err) {
-    LPSTR buf;
-    size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-                                 NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
-    if (!size) {
-        return "FormatMessageA failed";
-    }
-    std::string ret(buf, size);
-    LocalFree(buf);
-    return ret;
-}
-#endif
-
-template <typename T>
-struct no_init {
-    T value;
-    no_init() { /* do nothing */ }
-};
-
-struct llama_file {
-
-#if defined(_WIN32)
-    // use FILE * so we don't have to re-open the file to mmap
-    FILE * fp;
-    HANDLE fp_win32;
-    size_t size;
-
-private:
-    std::string GetErrorMessageWin32(DWORD error_code) const {
-        std::string ret;
-        LPSTR lpMsgBuf = NULL;
-        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-                                    NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
-        if (!bufLen) {
-            ret = format("Win32 error code: %lx", error_code);
-        } else {
-            ret = lpMsgBuf;
-            LocalFree(lpMsgBuf);
-        }
-
-        return ret;
-    }
-
-public:
-
-    llama_file(const char * fname, const char * mode) {
-        fp = ggml_fopen(fname, mode);
-        if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
-        }
-        fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
-    }
-
-    size_t tell() const {
-        // SetFilePointerEx returns the current position when seeking relative 0 bytes
-        LARGE_INTEGER li;
-        li.QuadPart = 0;
-        BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
-        if (!ret) {
-            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-        }
-
-        return li.QuadPart;
-    }
-
-    void seek(size_t offset, int whence) const {
-        // no need to convert SEEK_* to FILE_*. The enums are the same.
-        // Still, keep static asserts to avoid failures in the future.
-        static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
-        static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
-        static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");
-
-        LARGE_INTEGER li;
-        li.QuadPart = offset;
-        BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
-        if (!ret) {
-            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-        }
-    }
-
-    void read_raw(void * ptr, size_t len) const {
-        // On Win32, ReadFile is significantly faster than fread, which is in turn significantly faster than std::fstream.
-        // Thus, use the Win32 API for file I/O instead of the C/C++ library functions.
-
-        // There are conditions under which ReadFile cannot read chunks >64MB.
-        // Thus split the operation into smaller chunks if len exceeds this limit.
-        size_t bytes_read = 0;
-        while (bytes_read < len) {
-            size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
-            DWORD chunk_read = 0;
-            BOOL result = ReadFile(fp_win32, reinterpret_cast<char *>(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
-            if (!result) {
-                throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-            }
-            if (chunk_read < chunk_size || chunk_read == 0) {
-                throw std::runtime_error("unexpectedly reached end of file");
-            }
-
-            bytes_read += chunk_read;
-        }
-    }
-
-    uint32_t read_u32() const {
-        uint32_t val;
-        read_raw(&val, sizeof(val));
-        return val;
-    }
-
-    void write_raw(const void * ptr, size_t len) const {
-        // There are conditions under which WriteFile cannot write chunks >64MB.
-        // Thus split the operation into smaller chunks if len exceeds this limit.
-        size_t bytes_written = 0;
-        while (bytes_written < len) {
-            size_t chunk_size = std::min<size_t>(len - bytes_written, 64*1024*1024);
-            DWORD chunk_written = 0;
-            BOOL result = WriteFile(fp_win32, reinterpret_cast<const char *>(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
-            if (!result) {
-                throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-            }
-            if (chunk_written < chunk_size || chunk_written == 0) {
-                throw std::runtime_error("unexpectedly failed to write bytes");
-            }
-
-            bytes_written += chunk_written;
-        }
-    }
-
-    void write_u32(std::uint32_t val) const {
-        write_raw(&val, sizeof(val));
-    }
-
-    ~llama_file() {
-        if (fp) {
-            std::fclose(fp);
-        }
-    }
-#else
-    // use FILE * so we don't have to re-open the file to mmap
-    FILE * fp;
-    size_t size;
-
-    llama_file(const char * fname, const char * mode) {
-        fp = ggml_fopen(fname, mode);
-        if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
-        }
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
-    }
-
-    size_t tell() const {
-#ifdef _WIN32
-        __int64 ret = _ftelli64(fp);
-#else
-        long ret = std::ftell(fp);
-#endif
-        if (ret == -1) {
-            throw std::runtime_error(format("ftell error: %s", strerror(errno)));
-        }
-
-        return (size_t) ret;
-    }
-
-    void seek(size_t offset, int whence) const {
-#ifdef _WIN32
-        int ret = _fseeki64(fp, (__int64) offset, whence);
-#else
-        int ret = std::fseek(fp, (long) offset, whence);
-#endif
-        if (ret != 0) {
-            throw std::runtime_error(format("seek error: %s", strerror(errno)));
-        }
-    }
-
-    void read_raw(void * ptr, size_t len) const {
-        if (len == 0) {
-            return;
-        }
-        errno = 0;
-        std::size_t ret = std::fread(ptr, len, 1, fp);
-        if (ferror(fp)) {
-            throw std::runtime_error(format("read error: %s", strerror(errno)));
-        }
-        if (ret != 1) {
-            throw std::runtime_error("unexpectedly reached end of file");
-        }
-    }
-
-    uint32_t read_u32() const {
-        uint32_t ret;
-        read_raw(&ret, sizeof(ret));
-        return ret;
-    }
-
-    void write_raw(const void * ptr, size_t len) const {
-        if (len == 0) {
-            return;
-        }
-        errno = 0;
-        size_t ret = std::fwrite(ptr, len, 1, fp);
-        if (ret != 1) {
-            throw std::runtime_error(format("write error: %s", strerror(errno)));
-        }
-    }
-
-    void write_u32(std::uint32_t val) const {
-        write_raw(&val, sizeof(val));
-    }
-
-    ~llama_file() {
-        if (fp) {
-            std::fclose(fp);
-        }
-    }
-#endif
-};
-using llama_files = std::vector<std::unique_ptr<llama_file>>;
-
-struct llama_mmap {
-    void * addr;
-    size_t size;
-
-    llama_mmap(const llama_mmap &) = delete;
-
-#ifdef _POSIX_MAPPED_FILES
-    static constexpr bool SUPPORTED = true;
-
-    // list of mapped fragments (first_offset, last_offset)
-    std::vector<std::pair<size_t, size_t>> mapped_fragments;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) {
-        size = file->size;
-        int fd = fileno(file->fp);
-        int flags = MAP_SHARED;
-        // prefetch/readahead impairs performance on NUMA systems
-        if (numa)  { prefetch = 0; }
-#ifdef __linux__
-        // advise the kernel to read the file sequentially (increases readahead)
-        if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
-            LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
-                    strerror(errno));
-        }
-        if (prefetch) { flags |= MAP_POPULATE; }
-#endif
-        addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
-        if (addr == MAP_FAILED) { // NOLINT
-            throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
-        }
-
-        if (prefetch > 0) {
-            // advise the kernel to preload the mapped memory
-            if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
-                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
-                        strerror(errno));
-            }
-        }
-        if (numa) {
-            // advise the kernel not to use readahead
-            // (because the next page might not belong on the same node)
-            if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) {
-                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
-                        strerror(errno));
-            }
-        }
-
-        // initialize list of mapped_fragments
-        mapped_fragments.emplace_back(0, file->size);
-    }
-
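-    // shrink the range [*first, *last) so that both ends fall on page boundaries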
-    static void align_range(size_t * first, size_t * last, size_t page_size) {
-        // align first to the next page
-        size_t offset_in_page = *first & (page_size - 1);
-        size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
-        *first += offset_to_page;
-
-        // align last to the previous page
-        *last = *last & ~(page_size - 1);
-
-        if (*last <= *first) {
-            *last = *first;
-        }
-    }
-
-    // partially unmap the file in the range [first, last)
-    void unmap_fragment(size_t first, size_t last) {
-        // note: this function must not be called multiple times with overlapping ranges
-        // otherwise, there is a risk of invalidating addresses that have been repurposed for other mappings
-        int page_size = sysconf(_SC_PAGESIZE);
-        align_range(&first, &last, page_size);
-        size_t len = last - first;
-
-        if (len == 0) {
-            return;
-        }
-
-        GGML_ASSERT(first % page_size == 0);
-        GGML_ASSERT(last % page_size == 0);
-        GGML_ASSERT(last > first);
-
-        void * next_page_start = (uint8_t *) addr + first;
-
-        // unmap the range
-        if (munmap(next_page_start, len)) {
-            LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
-        }
-
-        // update the list of mapped fragments to avoid unmapping the same range again in the destructor
-        std::vector<std::pair<size_t, size_t>> new_mapped_fragments;
-        for (const auto & frag : mapped_fragments) {
-            if (frag.first < first && frag.second > last) {
-                // the range is in the middle of the fragment, split it
-                new_mapped_fragments.emplace_back(frag.first, first);
-                new_mapped_fragments.emplace_back(last, frag.second);
-            } else if (frag.first < first && frag.second > first) {
-                // the range starts in the middle of the fragment
-                new_mapped_fragments.emplace_back(frag.first, first);
-            } else if (frag.first < last && frag.second > last) {
-                // the range ends in the middle of the fragment
-                new_mapped_fragments.emplace_back(last, frag.second);
-            } else if (frag.first >= first && frag.second <= last) {
-                // the range covers the entire fragment
-            } else {
-                // the range is outside the fragment
-                new_mapped_fragments.push_back(frag);
-            }
-        }
-        mapped_fragments = std::move(new_mapped_fragments);
-    }
-
-    ~llama_mmap() {
-        for (const auto & frag : mapped_fragments) {
-            if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
-                LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
-            }
-        }
-    }
-#elif defined(_WIN32)
-    static constexpr bool SUPPORTED = true;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false) {
-        GGML_UNUSED(numa);
-
-        size = file->size;
-
-        HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
-
-        HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
-
-        if (hMapping == NULL) {
-            DWORD error = GetLastError();
-            throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
-        }
-
-        addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
-        DWORD error = GetLastError();
-        CloseHandle(hMapping);
-
-        if (addr == NULL) {
-            throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
-        }
-
-        if (prefetch > 0) {
-#if _WIN32_WINNT >= 0x602
-            // PrefetchVirtualMemory is only present on Windows 8 and above, so we dynamically load it
-            BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
-            HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");
-
-            // may fail on pre-Windows 8 systems
-            pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory");
-
-            if (pPrefetchVirtualMemory) {
-                // advise the kernel to preload the mapped memory
-                WIN32_MEMORY_RANGE_ENTRY range;
-                range.VirtualAddress = addr;
-                range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
-                if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
-                    LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
-                            llama_format_win_err(GetLastError()).c_str());
-                }
-            }
-#else
-            throw std::runtime_error("PrefetchVirtualMemory unavailable");
-#endif
-        }
-    }
-
-    void unmap_fragment(size_t first, size_t last) {
-        // not supported
-        GGML_UNUSED(first);
-        GGML_UNUSED(last);
-    }
-
-    ~llama_mmap() {
-        if (!UnmapViewOfFile(addr)) {
-            LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
-        }
-    }
-#else
-    static constexpr bool SUPPORTED = false;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = -1, bool numa = false) {
-        GGML_UNUSED(file);
-        GGML_UNUSED(prefetch);
-        GGML_UNUSED(numa);
-
-        throw std::runtime_error("mmap not supported");
-    }
-
-    void unmap_fragment(size_t first, size_t last) {
-        GGML_UNUSED(first);
-        GGML_UNUSED(last);
-
-        throw std::runtime_error("mmap not supported");
-    }
-#endif
-};
-using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
-
-// Represents some region of memory being locked using mlock or VirtualLock;
-// will automatically unlock on destruction.
-struct llama_mlock {
-    void * addr = NULL;
-    size_t size = 0;
-
-    bool failed_already = false;
-
-    llama_mlock() {}
-    llama_mlock(const llama_mlock &) = delete;
-
-    ~llama_mlock() {
-        if (size) {
-            raw_unlock(addr, size);
-        }
-    }
-
-    void init(void * ptr) {
-        GGML_ASSERT(addr == NULL && size == 0); // NOLINT
-        addr = ptr;
-    }
-
-    void grow_to(size_t target_size) {
-        GGML_ASSERT(addr);
-        if (failed_already) {
-            return;
-        }
-        size_t granularity = lock_granularity();
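-        // round the requested size up to a multiple of the lock granularity (page size)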
-        target_size = (target_size + granularity - 1) & ~(granularity - 1);
-        if (target_size > size) {
-            if (raw_lock((uint8_t *) addr + size, target_size - size)) {
-                size = target_size;
-            } else {
-                failed_already = true;
-            }
-        }
-    }
-
-#ifdef _POSIX_MEMLOCK_RANGE
-    static constexpr bool SUPPORTED = true;
-
-    static size_t lock_granularity() {
-        return (size_t) sysconf(_SC_PAGESIZE);
-    }
-
-    #ifdef __APPLE__
-        #define MLOCK_SUGGESTION \
-            "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
-            "decreasing 'vm.global_no_user_wire_amount'.  Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
-    #else
-        #define MLOCK_SUGGESTION \
-            "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
-    #endif
-
-    bool raw_lock(const void * addr, size_t size) const {
-        if (!mlock(addr, size)) {
-            return true;
-        }
-
-        char* errmsg = std::strerror(errno);
-        bool suggest = (errno == ENOMEM);
-
-        // Check if the resource limit is fine after all
-        struct rlimit lock_limit;
-        if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
-            suggest = false;
-        }
-        if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
-            suggest = false;
-        }
-
-        LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
-                size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
-        return false;
-    }
-
-    #undef MLOCK_SUGGESTION
-
-    static void raw_unlock(void * addr, size_t size) {
-        if (munlock(addr, size)) {
-            LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
-        }
-    }
-#elif defined(_WIN32)
-    static constexpr bool SUPPORTED = true;
-
-    static size_t lock_granularity() {
-        SYSTEM_INFO si;
-        GetSystemInfo(&si);
-        return (size_t) si.dwPageSize;
-    }
-
-    bool raw_lock(void * ptr, size_t len) const {
-        for (int tries = 1; ; tries++) {
-            if (VirtualLock(ptr, len)) {
-                return true;
-            }
-            if (tries == 2) {
-                LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
-                    len, size, llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-
-            // It failed but this was only the first try; increase the working
-            // set size and try again.
-            SIZE_T min_ws_size, max_ws_size;
-            if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
-                LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
-                        llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-            // Per MSDN: "The maximum number of pages that a process can lock
-            // is equal to the number of pages in its minimum working set minus
-            // a small overhead."
-            // Hopefully a megabyte is enough overhead:
-            size_t increment = len + 1048576;
-            // The minimum must be <= the maximum, so we need to increase both:
-            min_ws_size += increment;
-            max_ws_size += increment;
-            if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
-                LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
-                        llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-        }
-    }
-
-    static void raw_unlock(void * ptr, size_t len) {
-        if (!VirtualUnlock(ptr, len)) {
-            LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
-        }
-    }
-#else
-    static constexpr bool SUPPORTED = false;
-
-    static size_t lock_granularity() {
-        return (size_t) 65536;
-    }
-
-    bool raw_lock(const void * addr, size_t len) const {
-        LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
-        return false;
-    }
-
-    static void raw_unlock(const void * addr, size_t len) {}
-#endif
-};
-using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
-
-// NOTE: avoid ever using this except for building the token_to_piece caches
-static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
-    std::string piece;
-    piece.resize(piece.capacity());  // using string internal cache
-    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
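-    // a negative result means the buffer was too small; -n_chars is the required size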
-    if (n_chars < 0) {
-        piece.resize(-n_chars);
-        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
-        GGML_ASSERT(check == -n_chars);
-    }
-    else {
-        piece.resize(n_chars);
-    }
-
-    return piece;
-}
-
-//
-// globals
-//
-
-struct llama_logger_state {
-    ggml_log_callback log_callback = llama_log_callback_default;
-    void * log_callback_user_data = nullptr;
-};
-
-static llama_logger_state g_logger_state;
-
-// available llama models
-enum e_model {
-    MODEL_UNKNOWN,
-    MODEL_14M,
-    MODEL_17M,
-    MODEL_22M,
-    MODEL_33M,
-    MODEL_60M,
-    MODEL_70M,
-    MODEL_80M,
-    MODEL_109M,
-    MODEL_137M,
-    MODEL_160M,
-    MODEL_220M,
-    MODEL_250M,
-    MODEL_270M,
-    MODEL_335M,
-    MODEL_410M,
-    MODEL_450M,
-    MODEL_770M,
-    MODEL_780M,
-    MODEL_0_5B,
-    MODEL_1B,
-    MODEL_1_3B,
-    MODEL_1_4B,
-    MODEL_1_5B,
-    MODEL_1_6B,
-    MODEL_2B,
-    MODEL_2_8B,
-    MODEL_3B,
-    MODEL_4B,
-    MODEL_6B,
-    MODEL_6_9B,
-    MODEL_7B,
-    MODEL_8B,
-    MODEL_9B,
-    MODEL_11B,
-    MODEL_12B,
-    MODEL_13B,
-    MODEL_14B,
-    MODEL_15B,
-    MODEL_16B,
-    MODEL_20B,
-    MODEL_30B,
-    MODEL_32B,
-    MODEL_34B,
-    MODEL_35B,
-    MODEL_40B,
-    MODEL_65B,
-    MODEL_70B,
-    MODEL_236B,
-    MODEL_314B,
-    MODEL_SMALL,
-    MODEL_MEDIUM,
-    MODEL_LARGE,
-    MODEL_XL,
-    MODEL_A1_7B,
-    MODEL_A2_7B,
-    MODEL_8x7B,
-    MODEL_8x22B,
-    MODEL_16x12B,
-    MODEL_10B_128x3_66B,
-    MODEL_57B_A14B,
-    MODEL_27B,
-};
-
-static const size_t kiB = 1024;
-static const size_t MiB = 1024*kiB;
-static const size_t GiB = 1024*MiB;
-
-struct llama_hparams_posnet {
-    uint32_t n_embd;
-    uint32_t n_layer;
-};
-
-struct llama_hparams_convnext {
-    uint32_t n_embd;
-    uint32_t n_layer;
-};
-
-struct llama_hparams {
-    bool vocab_only;
-    bool rope_finetuned;
-    bool use_par_res;
-    bool swin_norm;
-
-    uint32_t n_vocab = 0;
-    uint32_t n_ctx_train; // context size the model was trained on
-    uint32_t n_embd;
-    uint32_t n_embd_features = 0;
-    uint32_t n_layer;
-    uint32_t n_rot;
-    uint32_t n_swa = 0; // sliding window attention (SWA)
-    uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
-    uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
-    uint32_t n_expert = 0;
-    uint32_t n_expert_used = 0;
-    uint32_t n_vocab_type = 0; // for BERT-style token types
-    uint32_t n_rel_attn_bkts = 0;
-
-    // for WavTokenizer
-    struct llama_hparams_posnet   posnet;
-    struct llama_hparams_convnext convnext;
-
-    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
-    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
-    std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
-
-    uint32_t n_layer_dense_lead = 0;
-    uint32_t n_lora_q = 0;
-    uint32_t n_lora_kv = 0;
-    uint32_t n_ff_exp = 0;
-    uint32_t n_ff_shexp = 0;
-    uint32_t n_expert_shared = 0;
-    float    expert_weights_scale = 0.0;
-
-    float f_norm_eps;
-    float f_norm_rms_eps;
-    float f_norm_group_eps;
-
-    uint32_t n_norm_groups;
-
-    float f_attn_logit_softcapping = 50.0f;
-    float f_final_logit_softcapping = 30.0f;
-
-    // for RWKV
-    uint32_t rescale_every_n_layers = 0;
-    uint32_t time_mix_extra_dim = 0;
-    uint32_t time_decay_extra_dim = 0;
-    uint32_t wkv_head_size = 0;
-
-    float     rope_attn_factor = 1.0f;
-    float     rope_freq_base_train;
-    float     rope_freq_scale_train;
-    uint32_t  n_ctx_orig_yarn;
-    float     rope_yarn_log_mul;
-    int       rope_sections[4];
-
-    // for State Space Models
-    uint32_t ssm_d_conv  = 0;
-    uint32_t ssm_d_inner = 0;
-    uint32_t ssm_d_state = 0;
-    uint32_t ssm_dt_rank = 0;
-    bool ssm_dt_b_c_rms = false;
-
-    float f_clamp_kqv      = 0.0f;
-    float f_max_alibi_bias = 0.0f;
-    float f_logit_scale    = 0.0f;
-
-    // Additional scale factors (Granite/Granite MoE)
-    float f_residual_scale  = 0.0f;
-    float f_embedding_scale = 0.0f;
-    float f_attention_scale = 0.0f;
-
-    bool causal_attn   = true;
-    bool use_alibi     = false;
-    bool attn_soft_cap = false;
-
-    // needed by encoder-decoder models (e.g. T5, FLAN-T5)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/8141
-    llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
-
-    enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
-    enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
-    enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
-
-    uint32_t n_head(uint32_t il = 0) const {
-        if (il < n_layer) {
-            return n_head_arr[il];
-        }
-
-        GGML_ABORT("fatal error");
-    }
-
-    uint32_t n_head_kv(uint32_t il = 0) const {
-        if (il < n_layer) {
-            return n_head_kv_arr[il];
-        }
-
-        GGML_ABORT("fatal error");
-    }
-
-    uint32_t n_ff(uint32_t il = 0) const {
-        if (il < n_layer) {
-            return n_ff_arr[il];
-        }
-
-        GGML_ABORT("fatal error");
-    }
-
-    uint32_t n_gqa(uint32_t il = 0) const {
-        const uint32_t n_head    = this->n_head(il);
-        const uint32_t n_head_kv = this->n_head_kv(il);
-
-        if (n_head_kv == 0) {
-            return 0;
-        }
-
-        return n_head/n_head_kv;
-    }
-
-    uint32_t n_embd_k_gqa(uint32_t il = 0) const { // dimension of key embeddings across all k-v heads
-        const uint32_t n_head_kv = this->n_head_kv(il);
-
-        return n_embd_head_k * n_head_kv;
-    }
-
-    uint32_t n_embd_v_gqa(uint32_t il = 0) const { // dimension of value embeddings across all k-v heads
-        const uint32_t n_head_kv = this->n_head_kv(il);
-
-        return n_embd_head_v * n_head_kv;
-    }
-
-    uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
-        // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-        if (wkv_head_size != 0) {
-            // for RWKV models
-            return 2 * n_embd;
-        }
-
-        // TODO: maybe support other convolution strides than 1
-        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
-    }
-
-    uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
-        if (wkv_head_size != 0) {
-            // corresponds to RWKV's wkv_states size
-            return n_embd * wkv_head_size;
-        }
-
-        // corresponds to Mamba's ssm_states size
-        return ssm_d_state * ssm_d_inner;
-    }
-};
-
-static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
-
-struct llama_cparams {
-    uint32_t n_ctx;           // context size used during inference
-    uint32_t n_batch;
-    uint32_t n_ubatch;
-    uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
-
-    float rope_freq_base;
-    float rope_freq_scale;
-
-    uint32_t n_ctx_orig_yarn;
-    // These hyperparameters are not exposed in GGUF, because all
-    // existing YaRN models use the same values for them.
-    float yarn_ext_factor;
-    float yarn_attn_factor;
-    float yarn_beta_fast;
-    float yarn_beta_slow;
-    float defrag_thold;
-
-    bool embeddings;
-    bool causal_attn;
-    bool offload_kqv;
-    bool flash_attn;
-    bool no_perf;
-
-    enum llama_pooling_type pooling_type;
-
-    ggml_backend_sched_eval_callback cb_eval;
-    void * cb_eval_user_data;
-};
-
-struct llama_layer_posnet {
-    // resnet
-    struct ggml_tensor * norm1   = nullptr;
-    struct ggml_tensor * norm1_b = nullptr;
-
-    struct ggml_tensor * conv1   = nullptr;
-    struct ggml_tensor * conv1_b = nullptr;
-
-    struct ggml_tensor * norm2   = nullptr;
-    struct ggml_tensor * norm2_b = nullptr;
-
-    struct ggml_tensor * conv2   = nullptr;
-    struct ggml_tensor * conv2_b = nullptr;
-
-    // attention
-    struct ggml_tensor * attn_norm   = nullptr;
-    struct ggml_tensor * attn_norm_b = nullptr;
-
-    struct ggml_tensor * attn_q   = nullptr;
-    struct ggml_tensor * attn_q_b = nullptr;
-
-    struct ggml_tensor * attn_k   = nullptr;
-    struct ggml_tensor * attn_k_b = nullptr;
-
-    struct ggml_tensor * attn_v   = nullptr;
-    struct ggml_tensor * attn_v_b = nullptr;
-
-    struct ggml_tensor * attn_o   = nullptr;
-    struct ggml_tensor * attn_o_b = nullptr;
-
-    // normalize
-    struct ggml_tensor * norm   = nullptr;
-    struct ggml_tensor * norm_b = nullptr;
-};
-
-struct llama_layer_convnext {
-    struct ggml_tensor * dw   = nullptr;
-    struct ggml_tensor * dw_b = nullptr;
-
-    struct ggml_tensor * norm   = nullptr;
-    struct ggml_tensor * norm_b = nullptr;
-
-    struct ggml_tensor * pw1   = nullptr;
-    struct ggml_tensor * pw1_b = nullptr;
-
-    struct ggml_tensor * pw2   = nullptr;
-    struct ggml_tensor * pw2_b = nullptr;
-
-    struct ggml_tensor * gamma = nullptr;
-};
-
-struct llama_layer {
-    // normalization
-    struct ggml_tensor * attn_norm       = nullptr;
-    struct ggml_tensor * attn_norm_b     = nullptr;
-    struct ggml_tensor * attn_norm_2     = nullptr;
-    struct ggml_tensor * attn_norm_2_b   = nullptr;
-    struct ggml_tensor * attn_q_norm     = nullptr;
-    struct ggml_tensor * attn_q_norm_b   = nullptr;
-    struct ggml_tensor * attn_k_norm     = nullptr;
-    struct ggml_tensor * attn_k_norm_b   = nullptr;
-    struct ggml_tensor * attn_out_norm   = nullptr;
-    struct ggml_tensor * attn_out_norm_b = nullptr;
-    struct ggml_tensor * attn_q_a_norm   = nullptr;
-    struct ggml_tensor * attn_kv_a_norm  = nullptr;
-    struct ggml_tensor * attn_sub_norm   = nullptr;
-    struct ggml_tensor * attn_post_norm  = nullptr;
-    struct ggml_tensor * ffn_sub_norm    = nullptr;
-    struct ggml_tensor * attn_norm_cross = nullptr;
-    struct ggml_tensor * attn_norm_enc   = nullptr;
-
-    // attention
-    struct ggml_tensor * wq        = nullptr;
-    struct ggml_tensor * wk        = nullptr;
-    struct ggml_tensor * wv        = nullptr;
-    struct ggml_tensor * wo        = nullptr;
-    struct ggml_tensor * wqkv      = nullptr;
-    struct ggml_tensor * wq_a      = nullptr;
-    struct ggml_tensor * wq_b      = nullptr;
-    struct ggml_tensor * wkv_a_mqa = nullptr;
-    struct ggml_tensor * wkv_b     = nullptr;
-    struct ggml_tensor * wq_cross  = nullptr;
-    struct ggml_tensor * wk_cross  = nullptr;
-    struct ggml_tensor * wv_cross  = nullptr;
-    struct ggml_tensor * wo_cross  = nullptr;
-    struct ggml_tensor * wq_enc    = nullptr;
-    struct ggml_tensor * wk_enc    = nullptr;
-    struct ggml_tensor * wv_enc    = nullptr;
-    struct ggml_tensor * wo_enc    = nullptr;
-
-    // attention bias
-    struct ggml_tensor * bq   = nullptr;
-    struct ggml_tensor * bk   = nullptr;
-    struct ggml_tensor * bv   = nullptr;
-    struct ggml_tensor * bo   = nullptr;
-    struct ggml_tensor * bqkv = nullptr;
-
-    // relative position bias
-    struct ggml_tensor * attn_rel_b       = nullptr;
-    struct ggml_tensor * attn_rel_b_enc   = nullptr;
-    struct ggml_tensor * attn_rel_b_cross = nullptr;
-
-    // normalization
-    struct ggml_tensor * ffn_norm         = nullptr;
-    struct ggml_tensor * ffn_norm_b       = nullptr;
-    struct ggml_tensor * ffn_post_norm    = nullptr;
-    struct ggml_tensor * layer_out_norm   = nullptr;
-    struct ggml_tensor * layer_out_norm_b = nullptr;
-    struct ggml_tensor * ffn_norm_exps    = nullptr;
-    struct ggml_tensor * ffn_norm_enc     = nullptr;
-
-    // ff
-    struct ggml_tensor * ffn_gate     = nullptr; // w1
-    struct ggml_tensor * ffn_down     = nullptr; // w2
-    struct ggml_tensor * ffn_up       = nullptr; // w3
-    struct ggml_tensor * ffn_gate_enc = nullptr;
-    struct ggml_tensor * ffn_down_enc = nullptr;
-    struct ggml_tensor * ffn_up_enc   = nullptr;
-
-    // ff MoE
-    struct ggml_tensor * ffn_gate_inp  = nullptr;
-    struct ggml_tensor * ffn_gate_exps = nullptr;
-    struct ggml_tensor * ffn_down_exps = nullptr;
-    struct ggml_tensor * ffn_up_exps   = nullptr;
-
-    // ff shared expert (shexp)
-    struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
-    struct ggml_tensor * ffn_gate_shexp     = nullptr;
-    struct ggml_tensor * ffn_down_shexp     = nullptr;
-    struct ggml_tensor * ffn_up_shexp       = nullptr;
-
-    // ff bias
-    struct ggml_tensor * ffn_gate_b = nullptr;
-    struct ggml_tensor * ffn_down_b = nullptr; // b2
-    struct ggml_tensor * ffn_up_b   = nullptr; // b3
-    struct ggml_tensor * ffn_act    = nullptr;
-
-    // mamba proj
-    struct ggml_tensor * ssm_in  = nullptr;
-    struct ggml_tensor * ssm_x   = nullptr;
-    struct ggml_tensor * ssm_dt  = nullptr;
-    struct ggml_tensor * ssm_out = nullptr;
-
-    // mamba
-    struct ggml_tensor * ssm_conv1d = nullptr;
-    struct ggml_tensor * ssm_a      = nullptr;
-    struct ggml_tensor * ssm_d      = nullptr;
-
-    // mamba bias
-    struct ggml_tensor * ssm_conv1d_b = nullptr;
-    struct ggml_tensor * ssm_dt_b     = nullptr;
-
-    // rwkv
-    struct ggml_tensor * time_mix_w1         = nullptr;
-    struct ggml_tensor * time_mix_w2         = nullptr;
-    struct ggml_tensor * time_mix_lerp_x     = nullptr;
-    struct ggml_tensor * time_mix_lerp_w     = nullptr;
-    struct ggml_tensor * time_mix_lerp_k     = nullptr;
-    struct ggml_tensor * time_mix_lerp_v     = nullptr;
-    struct ggml_tensor * time_mix_lerp_r     = nullptr;
-    struct ggml_tensor * time_mix_lerp_g     = nullptr;
-
-    struct ggml_tensor * time_mix_first      = nullptr;
-    struct ggml_tensor * time_mix_decay      = nullptr;
-    struct ggml_tensor * time_mix_decay_w1   = nullptr;
-    struct ggml_tensor * time_mix_decay_w2   = nullptr;
-    struct ggml_tensor * time_mix_key        = nullptr;
-    struct ggml_tensor * time_mix_value      = nullptr;
-    struct ggml_tensor * time_mix_receptance = nullptr;
-    struct ggml_tensor * time_mix_gate       = nullptr;
-
-    struct ggml_tensor * time_mix_ln     = nullptr;
-    struct ggml_tensor * time_mix_ln_b   = nullptr;
-    struct ggml_tensor * time_mix_output = nullptr;
-
-    struct ggml_tensor * channel_mix_lerp_k = nullptr;
-    struct ggml_tensor * channel_mix_lerp_r = nullptr;
-
-    struct ggml_tensor * channel_mix_key        = nullptr;
-    struct ggml_tensor * channel_mix_receptance = nullptr;
-    struct ggml_tensor * channel_mix_value      = nullptr;
-
-    // long rope factors
-    struct ggml_tensor * rope_long  = nullptr;
-    struct ggml_tensor * rope_short = nullptr;
-    struct ggml_tensor * rope_freqs = nullptr;
-
-    // bitnet scale
-    struct ggml_tensor * wq_scale       = nullptr;
-    struct ggml_tensor * wk_scale       = nullptr;
-    struct ggml_tensor * wv_scale       = nullptr;
-    struct ggml_tensor * wo_scale       = nullptr;
-    struct ggml_tensor * ffn_gate_scale = nullptr;
-    struct ggml_tensor * ffn_up_scale   = nullptr;
-    struct ggml_tensor * ffn_down_scale = nullptr;
-
-    struct llama_layer_posnet posnet;
-
-    struct llama_layer_convnext convnext;
-};
-
-// very similar to llama_batch,
-// but has more metadata about sequences
-struct llama_ubatch {
-    bool equal_seqs;
-    // TODO: whole_seqs for embeddings?
-
-    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence
-    uint32_t n_seqs;
-
-    llama_token  *  token;    // [n_tokens]
-    float        *  embd;     // [n_embd, n_tokens]
-    llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
-    int8_t       *  output;   // [n_tokens]
-};
-
-struct llama_kv_cell {
-    llama_pos pos   = -1;
-    llama_pos delta = 0;
-    int32_t   src   = -1; // used by recurrent state models to copy states
-    int32_t   tail  = -1;
-
-    std::set<llama_seq_id> seq_id;
-
-    bool has_seq_id(const llama_seq_id & id) const {
-        return seq_id.find(id) != seq_id.end();
-    }
-
-    bool is_empty() const {
-        return seq_id.empty();
-    }
-
-    bool is_same_seq(const llama_kv_cell & other) const {
-        return seq_id == other.seq_id;
-    }
-};
-
-// ring-buffer of cached KV data
-struct llama_kv_cache {
-    bool has_shift = false;
-    bool do_defrag = false;
-    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
-    bool v_trans   = true;  // the value tensor is transposed
-
-    // Note: The value of head isn't only used to optimize searching
-    // for a free KV slot. llama_decode_internal also uses it, so it
-    // cannot be freely changed after a slot has been allocated.
-    uint32_t head = 0;
-    uint32_t size = 0;
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
-
-    // computed before each graph build
-    uint32_t n = 0;
-
-    ggml_type type_k = GGML_TYPE_F16;
-    ggml_type type_v = GGML_TYPE_F16;
-
-    std::vector<llama_kv_cell> cells;
-
-    std::vector<struct ggml_tensor *> k_l; // per layer
-    std::vector<struct ggml_tensor *> v_l;
-
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    size_t total_size() {
-        size_t size = 0;
-        for (auto & buf : bufs) {
-            size += ggml_backend_buffer_get_size(buf.get());
-        }
-        return size;
-    }
-};
-
-struct llama_control_vector {
-    std::vector<struct ggml_tensor *> tensors; // per layer
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    int32_t layer_start = -1;
-    int32_t layer_end   = -1;
-
-    struct ggml_tensor * tensor_for(int il) const {
-        if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
-            return nullptr;
-        }
-        return tensors[il];
-    }
-
-    struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int  il) const {
-        ggml_tensor * layer_dir = tensor_for(il);
-        if (layer_dir != nullptr) {
-            cur = ggml_add(ctx, cur, layer_dir);
-        }
-        return cur;
-    }
-};
-
-struct llama_model {
-    e_model     type  = MODEL_UNKNOWN;
-    llm_arch    arch  = LLM_ARCH_UNKNOWN;
-    llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
-
-    std::string name = "n/a";
-
-    llama_hparams hparams = {};
-    llama_vocab   vocab;
-
-    struct ggml_tensor * tok_embd = nullptr;
-    struct ggml_tensor * type_embd = nullptr;
-    struct ggml_tensor * pos_embd = nullptr;
-    struct ggml_tensor * tok_norm = nullptr;
-    struct ggml_tensor * tok_norm_b = nullptr;
-
-    struct ggml_tensor * output_norm = nullptr;
-    struct ggml_tensor * output_norm_b = nullptr;
-    struct ggml_tensor * output = nullptr;
-    struct ggml_tensor * output_b = nullptr;
-    struct ggml_tensor * output_norm_enc = nullptr;
-
-    // classifier
-    struct ggml_tensor * cls = nullptr;
-    struct ggml_tensor * cls_b = nullptr;
-    struct ggml_tensor * cls_out   = nullptr;
-    struct ggml_tensor * cls_out_b = nullptr;
-
-    struct ggml_tensor * conv1d = nullptr;
-    struct ggml_tensor * conv1d_b = nullptr;
-
-    std::vector<llama_layer> layers;
-
-    // gguf metadata
-    std::unordered_map<std::string, std::string> gguf_kv;
-
-    llama_split_mode split_mode;
-    int main_gpu;
-    int n_gpu_layers;
-
-    std::vector<std::string> rpc_servers;
-
-    // list of devices used in this model
-    std::vector<ggml_backend_dev_t> devices;
-
-
-    // lists of buffer types used for each layer
-    using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
-    buft_list_t cpu_buft_list;
-    std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
-
-    struct layer_dev {
-        ggml_backend_dev_t dev;
-        buft_list_t * buft_list;
-    };
-    layer_dev dev_input = {};
-    layer_dev dev_output = {};
-    std::vector<layer_dev> dev_layer;
-
-    // contexts where the model tensors metadata is stored
-    std::vector<ggml_context_ptr> ctxs;
-
-    // the model memory buffers for the tensor data
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    // model memory mapped files
-    llama_mmaps mappings;
-
-    // objects representing data potentially being locked in memory
-    llama_mlocks mlock_bufs;
-    llama_mlocks mlock_mmaps;
-
-    // for quantize-stats only
-    std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
-
-    int64_t t_load_us  = 0;
-    int64_t t_start_us = 0;
-
-    // total number of parameters in the model
-    uint64_t n_elements = 0;
-
-    // total size of all the tensors in the model in bytes
-    size_t  n_bytes     = 0;
-
-    // keep track of loaded lora adapters
-    std::set<struct llama_lora_adapter *> lora_adapters;
-
-    ~llama_model() {
-       while (!lora_adapters.empty()) {
-            llama_lora_adapter_free(*lora_adapters.begin());
-        }
-    }
-};
-
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-    llama_seq_id * seq_id;
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
-    // sorted indices into the batch
-    std::vector<size_t> ids;
-    // batch indices of the output
-    std::vector<size_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatch
-    std::vector<llama_token>    ubatch_token;
-    std::vector<float>          ubatch_embd;
-    std::vector<llama_pos>      ubatch_pos;
-    std::vector<int32_t>        ubatch_n_seq_id;
-    std::vector<llama_seq_id *> ubatch_seq_id;
-    std::vector<int8_t>         ubatch_output;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
-        // clear empty sequences
-        // the previous ubatch is assumed to be gone,
-        // so nothing should refer to values in these sequences anymore.
-        for (size_t i = seq.size(); i-- > 0;) {
-            if (seq[i].length == 0) {
-                seq.pop_back();
-            } else {
-                break;
-            }
-        }
-        ubatch_token.resize(!has_embd ? n_ubatch : 0);
-        ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
-        ubatch_pos.resize(n_ubatch);
-        ubatch_n_seq_id.resize(n_ubatch);
-        ubatch_seq_id.resize(n_ubatch);
-        ubatch_output.resize(n_ubatch);
-        llama_ubatch ubatch = {
-            /*equal_seqs   =*/ true,
-            /*n_tokens     =*/ 0,
-            /*n_seq_tokens =*/ 0,
-            /*n_seqs       =*/ 0,
-            /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
-            /*embd         =*/ has_embd  ? ubatch_embd.data()  : nullptr,
-            /*pos          =*/ ubatch_pos.data(),
-            /*n_seq_id     =*/ ubatch_n_seq_id.data(),
-            /*seq_id       =*/ ubatch_seq_id.data(),
-            /*output       =*/ ubatch_output.data(),
-        };
-        return ubatch;
-    }
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
-        GGML_ASSERT(batch != nullptr);
-        GGML_ASSERT(length <= seq.length);
-        // Can only add sequences of equal lengths to a batch,
-        // otherwise it isn't clear to which sequence a token belongs
-        GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
-        GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
-        // NOTE: loops are separated for cache-friendliness
-        if (batch->token) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
-                }
-            } else {
-                // simple split
-                ubatch.token = batch->token + seq.offset;
-            }
-        } else {
-            ubatch.token = nullptr;
-        }
-        if (batch->embd) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    memcpy(
-                        ubatch.embd + n_embd * (ubatch.n_tokens + i),
-                        batch->embd + n_embd * ids[seq.offset + i],
-                        n_embd * sizeof(float)
-                    );
-                }
-            } else {
-                // simple split
-                ubatch.embd = batch->embd + (n_embd * seq.offset);
-            }
-        } else {
-            ubatch.embd = nullptr;
-        }
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-            }
-        } else {
-            // simple split
-            ubatch.pos = batch->pos + seq.offset;
-        }
-        if (ubatch.equal_seqs) {
-            ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
-            if (seq.seq_id) {
-                ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-            }
-        } else {
-            // simple split
-            if (batch->n_seq_id) {
-                ubatch.n_seq_id = batch->n_seq_id + seq.offset;
-            } else {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
-                }
-            }
-            if (batch->seq_id) {
-                ubatch.seq_id = batch->seq_id + seq.offset;
-            }
-        }
-        if (logits_all) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.output[ubatch.n_tokens + i] = 1;
-                out_ids.push_back(ids[seq.offset + i]);
-            }
-        } else if (batch->logits) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    size_t id = ids[seq.offset + i];
-                    int8_t is_output = batch->logits[id];
-                    ubatch.output[ubatch.n_tokens + i] = is_output;
-                    if (is_output) { out_ids.push_back(id); }
-                }
-            } else {
-                // simple split
-                ubatch.output = batch->logits + seq.offset;
-                for (size_t i = 0; i < length; ++i) {
-                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
-                }
-            }
-        } else {
-            // only get last output
-            for (size_t i = 0; i < length; ++i) {
-                size_t id = ids[seq.offset + i];
-                int8_t is_last = id == ids.size() - 1;
-                ubatch.output[ubatch.n_tokens + i] = is_last;
-                if (is_last) { out_ids.push_back(id); }
-            }
-        }
-        if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
-            ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
-        }
-        ubatch.n_tokens += length;
-        ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
-        seq.offset += length;
-        seq.length -= length;
-        n_tokens -= length;
-        GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
-    }
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch) {
-        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-        ubatch.equal_seqs = false;
-        if (!seq.empty()) {
-            llama_sbatch_seq & s = seq[0];
-            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-            GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
-            add_seq_to_ubatch(ubatch, s, length);
-        }
-        return ubatch;
-    }
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch) {
-        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-        if (!seq.empty()) {
-            size_t length = 0;
-            size_t n_tokens_in_ubatch = 0;
-            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
-            // smallest first, because it's easier to split this way;
-            // starting from the end to pop in constant time.
-            for (size_t i = seq.size(); i-- > 0;) {
-                llama_sbatch_seq & s = seq[i];
-                GGML_ASSERT(s.length > 0);
-                if (length == 0) {
-                    length = s.length < n_ubatch ? s.length : n_ubatch;
-                }
-                add_seq_to_ubatch(ubatch, s, length);
-                n_tokens_in_ubatch += length;
-                // shared prompts can't be mixed with any of their sequences,
-                // so it's safer to compute them in their own ubatch
-                if (s.n_seq_id > 1) { break; }
-                // stop when there isn't enough space for another sequence
-                if (length + n_tokens_in_ubatch > n_ubatch) { break; }
-            }
-        }
-        return ubatch;
-    }
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch) {
-        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-        if (!seq.empty()) {
-            llama_sbatch_seq & s = seq[seq.size() - 1];
-            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
-            add_seq_to_ubatch(ubatch, s, length);
-        }
-        return ubatch;
-    }
-
-    void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
-        GGML_ASSERT(batch.n_tokens >= 0);
-        this->batch = &batch;
-        this->n_embd = n_embd;
-        this->logits_all = logits_all;
-
-        n_tokens = batch.n_tokens;
-        ids.resize(n_tokens);
-        out_ids.clear();
-        // TODO: reserve out_ids and seq
-
-        for (size_t i = 0; i < n_tokens; ++i) {
-            ids[i] = i;
-        }
-        if (simple_split) {
-            seq.resize(1);
-            llama_sbatch_seq & s = seq[0];
-            s.n_seq_id = 0;
-            s.seq_id = nullptr;
-            s.offset = 0;
-            s.length = n_tokens;
-            return;
-        }
-        std::sort(ids.begin(), ids.end(),
-            [&batch](size_t a, size_t b) {
-                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
-                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
-                // sort by seq_id, then by pos
-                if (n_seq_a == n_seq_b) {
-                    if (batch.seq_id) {
-                        for (int32_t i = 0; i < n_seq_a; ++i) {
-                            llama_seq_id seq_id_a = batch.seq_id[a][i];
-                            llama_seq_id seq_id_b = batch.seq_id[b][i];
-                            // smaller seq_ids go first
-                            if (seq_id_a != seq_id_b) {
-                                return seq_id_a < seq_id_b;
-                            }
-                        }
-                    }
-                    // when all else is equal, sort by pos
-                    if (batch.pos) {
-                        return batch.pos[a] < batch.pos[b];
-                    }
-                    // no pos, sort by id
-                    return a < b;
-                }
-                // shared prompts go first
-                return n_seq_a > n_seq_b;
-            }
-        );
-        // init seq
-        llama_sbatch_seq * last_seq = nullptr;
-
-        for (size_t i = 0; i < n_tokens; ++i) {
-            const size_t bi = ids[i];
-            const int32_t n_seqs = batch.n_seq_id[bi];
-            llama_seq_id * seq_ids = batch.seq_id[bi];
-            if (last_seq != nullptr) {
-                bool same = n_seqs == last_seq->n_seq_id;
-                for (int32_t j = 0; same && j < n_seqs; ++j) {
-                    if (seq_ids[j] != last_seq->seq_id[j]) {
-                        same = false;
-                    }
-                }
-                if (same) {
-                    last_seq->length += 1;
-                    continue;
-                }
-            }
-            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
-            seq.push_back(new_seq);
-            last_seq = &seq.back();
-        }
-        // keep shared prompts at the end (they are split off first, since splitting starts from the back), then sort by length descending
-        std::sort(seq.begin(), seq.end(),
-            [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
-                if (a.n_seq_id == b.n_seq_id) {
-                    return a.length > b.length;
-                }
-                return a.n_seq_id < b.n_seq_id;
-            }
-        );
-    }
-};
-
-struct llama_context {
-    llama_context(const llama_model & model)
-        : model(model)
-        , t_start_us(model.t_start_us)
-        , t_load_us(model.t_load_us) {}
-
-    const struct llama_model & model;
-
-    struct llama_cparams        cparams;
-    struct llama_sbatch         sbatch;
-    struct llama_kv_cache       kv_self;
-    struct llama_control_vector cvec;
-
-    std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
-
-    std::vector<ggml_backend_ptr> backends;
-    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
-
-    ggml_backend_t backend_cpu = nullptr;
-
-    ggml_threadpool_t threadpool       = nullptr;
-    ggml_threadpool_t threadpool_batch = nullptr;
-
-    bool has_evaluated_once = false;
-
-    mutable int64_t t_start_us;
-    mutable int64_t t_load_us;
-    mutable int64_t t_p_eval_us = 0;
-    mutable int64_t t_eval_us   = 0;
-
-    mutable int64_t t_compute_start_us = 0;
-    mutable int64_t n_queued_tokens = 0;
-
-    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-    mutable int32_t n_eval   = 0; // number of eval calls
-
-    // host buffer for the model output (logits and embeddings)
-    ggml_backend_buffer_ptr buf_output;
-
-    // decode output (2-dimensional array: [n_outputs][n_vocab])
-    size_t  logits_size = 0; // capacity (of floats) for logits
-    float * logits      = nullptr;
-
-    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
-    size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch or last logical batch
-
-    bool logits_all = false;
-
-    // embeddings output (2-dimensional array: [n_outputs][n_embd])
-    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    size_t  embd_size = 0; // capacity (of floats) for embeddings
-    float * embd      = nullptr;
-
-    // sequence embeddings output (map of [n_embd] vectors)
-    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
-    std::map<llama_seq_id, std::vector<float>> embd_seq;
-
-    // whether we are computing encoder output or decoder output
-    bool is_encoding = false;
-
-    // TODO: find a better way to accommodate multi-dimensional position encoding methods
-    // number of position ids each token gets, 1 per token in most cases.
-    // when using m-rope, each token gets 3 position ids representing a 3-dimensional coordinate.
-    int n_pos_per_token = 1;
-
-    // output of the encoder part of the encoder-decoder models
-    std::vector<float> embd_enc;
-    std::vector<std::set<llama_seq_id>> seq_ids_enc;
-
-    // memory buffers used to evaluate the model
-    std::vector<uint8_t> buf_compute_meta;
-    ggml_backend_sched_ptr sched;
-
-    ggml_abort_callback abort_callback      = nullptr;
-    void *              abort_callback_data = nullptr;
-
-    // input tensors
-    struct ggml_tensor * inp_tokens;      // I32 [n_batch]
-    struct ggml_tensor * inp_embd;        // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;         // I32 [n_batch]
-    struct ggml_tensor * inp_out_ids;     // I32 [n_outputs]
-    struct ggml_tensor * inp_KQ_mask;     // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_K_shift;     // I32 [kv_size]
-    struct ggml_tensor * inp_mean;        // F32 [n_batch, n_batch]
-    struct ggml_tensor * inp_cls;         // I32 [n_batch]
-    struct ggml_tensor * inp_s_copy;      // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;      // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;       // I32 [n_kv, n_batch]
-    struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
-    struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
-    struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
-};
-
-struct llama_lora_weight {
-    struct ggml_tensor * a = nullptr;
-    struct ggml_tensor * b = nullptr;
-    llama_lora_weight() = default;
-    llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
-};
-
-struct llama_lora_adapter {
-    struct llama_model * base_model;
-    // map tensor name to lora_a_b
-    std::unordered_map<std::string, struct llama_lora_weight> ab_map;
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    float alpha;
-
-    llama_lora_adapter(struct llama_model * base_model): base_model(base_model) {
-        base_model->lora_adapters.insert(this);
-    }
-
-    llama_lora_weight * get_weight(struct ggml_tensor * w) {
-        std::string name(w->name);
-        auto pos = ab_map.find(name);
-        if (pos != ab_map.end()) {
-            return &pos->second;
-        }
-        return nullptr;
-    }
-
-    ~llama_lora_adapter() {
-        auto pos = base_model->lora_adapters.find(this);
-        if (pos != base_model->lora_adapters.end()) {
-            base_model->lora_adapters.erase(pos);
-        }
-    }
-};
-
 static int llama_get_device_count(const llama_model & model) {
     return (int) model.devices.size();
 }
 
-static struct ggml_tensor * llama_get_model_tensor(const struct llama_model * model, const char * name) {
-    auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
-            [name](const std::pair<std::string, struct ggml_tensor *> & it) {
-                return it.first == name;
-            });
-    if (it == model->tensors_by_name.end()) {
-        return nullptr;
-    }
-    return it->second;
-}
-
-template<typename F>
-static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
-    ggml_init_params params = {
-        /*.mem_size   =*/ ggml_tensor_overhead()*8,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-    ggml_context_ptr ctx { ggml_init(params) };
-    if (!ctx) {
-        throw std::runtime_error(format("failed to create ggml context"));
-    }
-
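-    // build the op tensor via fn() and attach a dummy buffer of the target type to its sources,
-    // so the device can be queried for support of that op on that buffer type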
-    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
-    ggml_tensor * op_tensor = fn(ctx.get());
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (op_tensor->src[i] != nullptr) {
-            assert(op_tensor->src[i]->buffer == nullptr);
-            op_tensor->src[i]->buffer = buf.get();
-        }
-    }
-    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
-
-    return op_supported;
-}
-
-template<typename F>
-static ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) {
-    for (const auto & cur : buft_list) {
-        ggml_backend_dev_t cur_dev = cur.first;
-        ggml_backend_buffer_type_t cur_buft = cur.second;
-        if (buft_supported(cur_buft, cur_dev, fn)) {
-            return cur_buft;
-        }
-    }
-    throw std::runtime_error(format("no suitable buffer type found"));
-}
-
-//
-// kv cache helpers
-//
-
-static bool llama_kv_cache_init(
-             struct llama_kv_cache & cache,
-               const llama_context * ctx,
-                         ggml_type   type_k,
-                         ggml_type   type_v,
-                          uint32_t   kv_size,
-                              bool   offload) {
-    const llama_model & model = ctx->model;
-    const llama_cparams & cparams = ctx->cparams;
-
-    const struct llama_hparams & hparams = model.hparams;
-
-    const int32_t n_layer = hparams.n_layer;
-
-    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
-
-    cache.has_shift = false;
-
-    cache.recurrent = llama_model_is_recurrent(&model);
-    cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
-
-    cache.head = 0;
-    cache.size = kv_size;
-    cache.used = 0;
-
-    cache.type_k = type_k;
-    cache.type_v = type_v;
-
-    cache.cells.clear();
-    cache.cells.resize(kv_size);
-
-    // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            struct ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-            ggml_context * ctx = ggml_init(params);
-            if (!ctx) {
-                return nullptr;
-            }
-            ctx_map[buft] = ctx;
-            cache.ctxs.emplace_back(ctx);
-            return ctx;
-        }
-        return it->second;
-    };
-
-    cache.k_l.reserve(n_layer);
-    cache.v_l.reserve(n_layer);
-
-    for (int i = 0; i < n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
-        LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
-
-        ggml_backend_buffer_type_t buft;
-        if (offload) {
-            auto * dev = model.dev_layer.at(i).dev;
-            buft = ggml_backend_dev_buffer_type(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
-        }
-        ggml_context * ctx = ctx_for_buft(buft);
-
-        if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
-            return false;
-        }
-
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
-        ggml_format_name(k, "cache_k_l%d", i);
-        ggml_format_name(v, "cache_v_l%d", i);
-        cache.k_l.push_back(k);
-        cache.v_l.push_back(v);
-    }
-
-    // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-        if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
-            return false;
-        }
-        ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        cache.bufs.emplace_back(buf);
-    }
-
-    return true;
-}
-
-// a structure that holds information about the slot found in llama_kv_cache_find_slot
-struct llama_kv_cache_slot_info {
-    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
-    bool found = false;                       // the slot was found
-
-    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
-    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
-
-    operator bool() const { return found; }
-};
-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
-// find an empty slot of size "n_tokens" in the cache
-// updates the cache head
-// returns a structure holding information about the slot found
-// Note: On success, it's important that cache.head points
-// to the first cell of the slot.
-static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
-           struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs   = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
-
-    if (cache.recurrent) {
-        // For recurrent state architectures (like Mamba or RWKV),
-        // each cache cell can store the state for a whole sequence.
-        // A slot should always be contiguous.
-
-        // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
-
-        int32_t min = cache.size - 1;
-        int32_t max = 0;
-
-        // everything should fit if all seq_ids are smaller than the max
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
-            for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
-
-                if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
-                    // too big seq_id
-                    // TODO: would it be possible to resize the cache instead?
-                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
-                    return llama_kv_cache_slot_info_failed;
-                }
-                if (j > 0) {
-                    llama_kv_cell & seq = cache.cells[seq_id];
-                    if (seq.tail >= 0) {
-                        llama_kv_cell & cell = cache.cells[seq.tail];
-                        // clear cells from seq_ids that become shared
-                        // (should not normally happen, but let's handle it anyway)
-                        cell.seq_id.erase(seq_id);
-                        seq.tail = -1;
-                        if (cell.seq_id.empty()) {
-                            cell.pos = -1;
-                            cell.src = -1;
-                            cache.used -= 1;
-                        }
-                    }
-                }
-            }
-        }
-
-#ifndef NDEBUG
-        {
-            std::vector<int32_t> tails_verif;
-            tails_verif.assign(cache.size, -1);
-            for (uint32_t i = 0; i < cache.size; ++i) {
-                llama_kv_cell & cell = cache.cells[i];
-                for (llama_seq_id seq_id : cell.seq_id) {
-                    if (tails_verif[seq_id] != -1) {
-                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
-                    }
-                    tails_verif[seq_id] = i;
-                }
-            }
-            for (uint32_t i = 0; i < cache.size; ++i) {
-                if (tails_verif[i] != cache.cells[i].tail) {
-                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
-                }
-            }
-        }
-#endif
-
-        // find next empty cell
-        uint32_t next_empty_cell = cache.head;
-
-        for (uint32_t i = 0; i < cache.size; ++i) {
-            if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
-            llama_kv_cell & cell = cache.cells[next_empty_cell];
-            if (cell.is_empty()) { break; }
-            next_empty_cell += 1;
-        }
-
-        // find usable cell range
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
-            llama_kv_cell & seq_meta = cache.cells[seq_id];
-            bool has_cell = false;
-            if (seq_meta.tail >= 0) {
-                llama_kv_cell & cell = cache.cells[seq_meta.tail];
-                GGML_ASSERT(cell.has_seq_id(seq_id));
-                // does this seq_id "own" the cell?
-                if (cell.seq_id.size() == 1) { has_cell = true; }
-            }
-            if (!has_cell) {
-                llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
-                GGML_ASSERT(empty_cell.is_empty());
-                // copy old tail into the empty cell
-                if (seq_meta.tail >= 0) {
-                    llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
-                    empty_cell.pos = orig_cell.pos;
-                    empty_cell.src = orig_cell.src;
-                    orig_cell.seq_id.erase(seq_id);
-                    empty_cell.seq_id.insert(seq_id); // will be overwritten
-                }
-                seq_meta.tail = next_empty_cell;
-                // find next empty cell
-                if (s + 1 < n_seqs) {
-                    next_empty_cell += 1;
-                    for (uint32_t i = 0; i < cache.size; ++i) {
-                        if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
-                        llama_kv_cell & cell = cache.cells[next_empty_cell];
-                        if (cell.is_empty()) { break; }
-                        next_empty_cell += 1;
-                    }
-                }
-            }
-            if (min > seq_meta.tail) { min = seq_meta.tail; }
-            if (max < seq_meta.tail) { max = seq_meta.tail; }
-        }
-
-        // gather and re-order
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
-            if (dst_id != src_id) {
-                llama_kv_cell & dst_cell = cache.cells[dst_id];
-                llama_kv_cell & src_cell = cache.cells[src_id];
-
-                std::swap(dst_cell.pos, src_cell.pos);
-                std::swap(dst_cell.src, src_cell.src);
-                std::swap(dst_cell.seq_id, src_cell.seq_id);
-
-                // swap tails (assuming they NEVER overlap)
-                for (const llama_seq_id seq_id : src_cell.seq_id) {
-                    cache.cells[seq_id].tail = src_id;
-                }
-                for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                    cache.cells[seq_id].tail = dst_id;
-                }
-            }
-        }
-
-        // update the pos of the used seqs
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-            int32_t cell_id = s + min;
-            llama_kv_cell & cell = cache.cells[cell_id];
-
-            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
-                // What should happen when the pos backtracks or skips a value?
-                // Clearing the state mid-batch would require special-casing which isn't done.
-                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
-            }
-            cell.pos = last_pos;
-            cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
-                cell.seq_id.insert(seq_id);
-                cache.cells[seq_id].tail = cell_id;
-            }
-        }
-
-        // allow getting the range of used cells, from head to head + n
-        cache.head = min;
-        cache.n    = max - min + 1;
-        cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
-            [](const llama_kv_cell& cell){ return !cell.is_empty(); });
-
-        // sanity check
-        return llama_kv_cache_slot_info(cache.n >= n_seqs);
-    }
-    // otherwise, one cell per token.
-
-    if (n_tokens > cache.size) {
-        LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
-        return llama_kv_cache_slot_info_failed;
-    }
-
-    uint32_t n_tested = 0;
-
-    while (true) {
-        if (cache.head + n_tokens > cache.size) {
-            n_tested += cache.size - cache.head;
-            cache.head = 0;
-            continue;
-        }
-
-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            if (cache.cells[cache.head + i].pos >= 0) {
-                found = false;
-                cache.head += i + 1;
-                n_tested   += i + 1;
-                break;
-            }
-        }
-
-        if (found) {
-            break;
-        }
-
-        if (n_tested >= cache.size) {
-            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return llama_kv_cache_slot_info_failed;
-        }
-    }
-
-    for (uint32_t s = 0; s < n_seqs; s++) {
-        for (uint32_t i = 0; i < n_seq_tokens; ++i) {
-            uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
-
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
-            }
-        }
-    }
-
-    cache.used += n_tokens;
-
-    return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
-}
-
-// find how many cells are currently in use
-static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
-    for (uint32_t i = cache.size; i > 0; --i) {
-        const llama_kv_cell & cell = cache.cells[i - 1];
-
-        if (cell.pos >= 0 && !cell.is_empty()) {
-            return i;
-        }
-    }
-
-    return 0;
-}
-
-static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
-    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
-        cache.cells[i].pos = -1;
-        cache.cells[i].seq_id.clear();
-        cache.cells[i].src = -1;
-        cache.cells[i].tail = -1;
-    }
-    cache.head = 0;
-    cache.used = 0;
-
-    for (auto & buf : cache.bufs) {
-        ggml_backend_buffer_clear(buf.get(), 0);
-    }
-}
-
-static bool llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1) {
-    uint32_t new_head = cache.size;
-
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
-
-    // models like Mamba or RWKV can't have a state partially erased
-    if (cache.recurrent) {
-        if (seq_id >= (int64_t) cache.size) {
-            // could be fatal
-            return false;
-        }
-        if (0 <= seq_id) {
-            int32_t & tail_id = cache.cells[seq_id].tail;
-            if (tail_id >= 0) {
-                const llama_kv_cell & cell = cache.cells[tail_id];
-                // partial intersection is invalid
-                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
-                    return false;
-                }
-                // invalidate tails which will be cleared
-                if (p0 <= cell.pos && cell.pos < p1) {
-                    tail_id = -1;
-                }
-            }
-        } else {
-            // seq_id is negative, so the range should include everything or nothing
-            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
-                return false;
-            }
-        }
-    }
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            if (seq_id < 0) {
-                cache.cells[i].seq_id.clear();
-            } else if (cache.cells[i].has_seq_id(seq_id)) {
-                cache.cells[i].seq_id.erase(seq_id);
-            } else {
-                continue;
-            }
-            if (cache.cells[i].is_empty()) {
-                // keep count of the number of used cells
-                if (cache.cells[i].pos >= 0) cache.used--;
-
-                cache.cells[i].pos = -1;
-                cache.cells[i].src = -1;
-                if (new_head == cache.size) new_head = i;
-            }
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
-
-    return true;
-}
-
-static void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id_src,
-                 llama_seq_id   seq_id_dst,
-                    llama_pos   p0,
-                    llama_pos   p1) {
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
-
-    if (cache.recurrent) {
-        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            llama_kv_cell & tail_src = cache.cells[seq_id_src];
-            llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
-            if (tail_dst.tail >= 0) {
-                // clear destination seq_id if it wasn't empty
-                llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
-
-                cell_dst.seq_id.erase(seq_id_dst);
-                tail_dst.tail = -1;
-                if (cell_dst.seq_id.empty()) {
-                    cell_dst.pos = -1;
-                    cell_dst.delta = -1;
-                    cell_dst.src = -1;
-                    cache.used -= 1;
-                }
-            }
-            if (tail_src.tail >= 0) {
-                llama_kv_cell & cell_src = cache.cells[tail_src.tail];
-
-                cell_src.seq_id.insert(seq_id_dst);
-                tail_dst.tail = tail_src.tail;
-            }
-        }
-
-        return;
-    }
-    // otherwise, this is the KV cache of a Transformer-like model
-
-    cache.head = 0;
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.cells[i].seq_id.insert(seq_id_dst);
-        }
-    }
-}
-
-static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
-    uint32_t new_head = cache.size;
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.recurrent && (llama_seq_id) i != seq_id) {
-            cache.cells[i].tail = -1;
-        }
-        if (!cache.cells[i].has_seq_id(seq_id)) {
-            if (cache.cells[i].pos >= 0) cache.used--;
-            cache.cells[i].pos = -1;
-            cache.cells[i].src = -1;
-            cache.cells[i].seq_id.clear();
-            if (new_head == cache.size) new_head = i;
-        } else {
-            cache.cells[i].seq_id.clear();
-            cache.cells[i].seq_id.insert(seq_id);
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
-}
-
-static void llama_kv_cache_seq_add(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                    llama_pos   delta) {
-    uint32_t new_head = cache.size;
-
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
-    // If there is no range then return early to avoid looping over the cache.
-    if (p0 == p1) return;
-
-    if (cache.recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be shifted
-        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            const int32_t tail_id = cache.cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cache.cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos += delta;
-                }
-            }
-        }
-        return;
-    }
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.has_shift = true;
-            cache.cells[i].pos   += delta;
-            cache.cells[i].delta += delta;
-
-            if (cache.cells[i].pos < 0) {
-                if (!cache.cells[i].is_empty()) {
-                    cache.used--;
-                }
-                cache.cells[i].pos = -1;
-                cache.cells[i].seq_id.clear();
-                if (new_head == cache.size) {
-                    new_head = i;
-                }
-            }
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    // Otherwise we just start the next search from the beginning.
-    cache.head = new_head != cache.size ? new_head : 0;
-}
-
-static void llama_kv_cache_seq_div(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                          int   d) {
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
-    // If there is no range then return early to avoid looping over the cache.
-    if (p0 == p1) return;
-
-    if (cache.recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be changed
-        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            const int32_t tail_id = cache.cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cache.cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos /= d;
-                }
-            }
-        }
-        return;
-    }
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.has_shift = true;
-
-            {
-                llama_pos p_old = cache.cells[i].pos;
-                cache.cells[i].pos   /= d;
-                cache.cells[i].delta += cache.cells[i].pos - p_old;
-            }
-        }
-    }
-}
-
-static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
-    llama_pos result = 0;
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id)) {
-            result = std::max(result, cache.cells[i].pos);
-        }
-    }
-
-    return result;
-}
-
-static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
-    if (!cache.recurrent) {
-        cache.do_defrag = true;
-    }
-}
-
-static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
-
-// saves the kv_cache state for future recovery.
-// used to rollback llama_kv_cache_find_slot changes.
-struct llama_kv_slot_restorer {
-    struct llama_kv_cache_state {
-        uint32_t head = 0;
-        uint32_t n    = 0;
-    } old_state;
-
-    // for non-recurrent models only
-    // list of slots to restore
-    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
-
-    bool do_restore = false;
-
-    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
-        old_state.head  = cache.head;
-        old_state.n     = cache.n;
-    }
-
-    // saves slot information for future restoration
-    void save(const struct llama_kv_cache_slot_info & slot) {
-        if (slot) {
-            do_restore = true;
-            if (slot.boundaries.first != slot.boundaries.second) {
-                slot_boundaries.push_back(slot.boundaries);
-            }
-        }
-    }
-
-    // must be explicitly called to restore the kv_cache state
-    // and rollback changes from all llama_kv_cache_find_slot calls
-    void restore(struct llama_kv_cache & cache) {
-        if (do_restore) {
-            cache.head  = old_state.head;
-            cache.n     = old_state.n;
-
-            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
-                llama_kv_cache_seq_rm(cache, -1, -1, -1);
-            } else {
-                for (auto & slot : slot_boundaries) {
-                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
-                }
-            }
-        }
-    }
-};
-
-//
-// model loading and saving
-//
-
-enum llama_fver {
-    GGUF_FILE_VERSION_V1 = 1,
-    GGUF_FILE_VERSION_V2 = 2,
-    GGUF_FILE_VERSION_V3 = 3,
-};
-
-static const char * llama_file_version_name(llama_fver version) {
-    switch (version) {
-        case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
-        case GGUF_FILE_VERSION_V2: return "GGUF V2";
-        case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
-    }
-
-    return "unknown";
-}
-
-static std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
-    char buf[256];
-    snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
-    for (size_t i = 1; i < ne.size(); i++) {
-        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
-    }
-    return buf;
-}
-
-static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
-    char buf[256];
-    snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
-    }
-    return buf;
-}
-
-namespace GGUFMeta {
-    template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
-    struct GKV_Base_Type {
-        static constexpr gguf_type gt = gt_;
-
-        static T getter(const gguf_context * ctx, const int kid) {
-            return gfun(ctx, kid);
-        }
-    };
-
-    template<typename T> struct GKV_Base;
-
-    template<> struct GKV_Base<bool        >: GKV_Base_Type<bool,         GGUF_TYPE_BOOL,    gguf_get_val_bool> {};
-    template<> struct GKV_Base<uint8_t     >: GKV_Base_Type<uint8_t,      GGUF_TYPE_UINT8,   gguf_get_val_u8  > {};
-    template<> struct GKV_Base<uint16_t    >: GKV_Base_Type<uint16_t,     GGUF_TYPE_UINT16,  gguf_get_val_u16 > {};
-    template<> struct GKV_Base<uint32_t    >: GKV_Base_Type<uint32_t,     GGUF_TYPE_UINT32,  gguf_get_val_u32 > {};
-    template<> struct GKV_Base<uint64_t    >: GKV_Base_Type<uint64_t,     GGUF_TYPE_UINT64,  gguf_get_val_u64 > {};
-    template<> struct GKV_Base<int8_t      >: GKV_Base_Type<int8_t,       GGUF_TYPE_INT8,    gguf_get_val_i8  > {};
-    template<> struct GKV_Base<int16_t     >: GKV_Base_Type<int16_t,      GGUF_TYPE_INT16,   gguf_get_val_i16 > {};
-    template<> struct GKV_Base<int32_t     >: GKV_Base_Type<int32_t,      GGUF_TYPE_INT32,   gguf_get_val_i32 > {};
-    template<> struct GKV_Base<int64_t     >: GKV_Base_Type<int64_t,      GGUF_TYPE_INT64,   gguf_get_val_i64 > {};
-    template<> struct GKV_Base<float       >: GKV_Base_Type<float,        GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
-    template<> struct GKV_Base<double      >: GKV_Base_Type<double,       GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
-    template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING,  gguf_get_val_str > {};
-
-    template<> struct GKV_Base<std::string> {
-        static constexpr gguf_type gt = GGUF_TYPE_STRING;
-
-        static std::string getter(const gguf_context * ctx, const int kid) {
-            return gguf_get_val_str(ctx, kid);
-        }
-    };
-
-    struct ArrayInfo {
-        const gguf_type gt;
-        const size_t length;
-        const void * data;
-    };
-
-    template<> struct GKV_Base<ArrayInfo> {
-        public:
-        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
-        static ArrayInfo getter(const gguf_context *ctx, const int k) {
-            return ArrayInfo {
-                gguf_get_arr_type(ctx, k),
-                size_t(gguf_get_arr_n(ctx, k)),
-                gguf_get_arr_data(ctx, k),
-            };
-        }
-    };
-
-    template<typename T>
-    class GKV : public GKV_Base<T> {
-        GKV() = delete;
-
-        public:
-        static T get_kv(const gguf_context * ctx, const int k) {
-            const enum gguf_type kt = gguf_get_kv_type(ctx, k);
-
-            if (kt != GKV::gt) {
-                throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
-                    gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
-            }
-            return GKV::getter(ctx, k);
-        }
-
-        static const char * override_type_to_str(const llama_model_kv_override_type ty) {
-            switch (ty) {
-                case LLAMA_KV_OVERRIDE_TYPE_BOOL:  return "bool";
-                case LLAMA_KV_OVERRIDE_TYPE_INT:   return "int";
-                case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
-                case LLAMA_KV_OVERRIDE_TYPE_STR:   return "str";
-            }
-            return "unknown";
-        }
-
-        static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) {
-            if (!ovrd) { return false; }
-            if (ovrd->tag == expected_type) {
-                LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
-                    __func__, override_type_to_str(ovrd->tag), ovrd->key);
-                switch (ovrd->tag) {
-                    case LLAMA_KV_OVERRIDE_TYPE_BOOL:  {
-                        LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false");
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_INT:   {
-                        LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64);
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
-                        LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64);
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_STR: {
-                        LLAMA_LOG_INFO("%s\n", ovrd->val_str);
-                    } break;
-                    default:
-                        // Shouldn't be possible to end up here, but just in case...
-                        throw std::runtime_error(
-                            format("Unsupported attempt to override %s type for metadata key %s\n",
-                                override_type_to_str(ovrd->tag), ovrd->key));
-                }
-                return true;
-            }
-            LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
-                __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag));
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
-                target = ovrd->val_bool;
-                return true;
-            }
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
-                target = ovrd->val_i64;
-                return true;
-            }
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
-                target = ovrd->val_f64;
-                return true;
-            }
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
-                target = ovrd->val_str;
-                return true;
-            }
-            return false;
-        }
-
-        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            if (try_override(target, ovrd)) {
-                return true;
-            }
-            if (k < 0) { return false; }
-            target = get_kv(ctx, k);
-            return true;
-        }
-
-        static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            return set(ctx, gguf_find_key(ctx, key), target, ovrd);
-        }
-
-        static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            return set(ctx, key.c_str(), target, ovrd);
-        }
-    };
-}
-
-using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
-
-static size_t llama_model_max_nodes(const llama_model & model) {
-    return std::max<size_t>(8192, model.tensors_by_name.size()*5);
-}
-
-struct llama_model_loader {
-    int n_kv      = 0;
-    int n_tensors = 0;
-    int n_created = 0;
-
-    uint64_t n_elements = 0;
-    size_t  n_bytes     = 0;
-
-    bool use_mmap = false;
-    bool check_tensors;
-
-    llama_files files;
-    llama_ftype ftype;
-    llama_fver  fver;
-
-    llama_mmaps mappings;
-
-    // Holds information on a model weight
-    struct llama_tensor_weight {
-        uint16_t  idx; // source file index
-        size_t   offs; // tensor data offset in the original file
-
-        ggml_tensor * tensor;
-
-        llama_tensor_weight(const llama_file * file, uint16_t idx, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
-            const int tensor_idx = gguf_find_tensor(gguf_ctx,  ggml_get_name(tensor));
-            if (tensor_idx < 0) {
-                throw std::runtime_error(format("tensor '%s' not found in the model", ggml_get_name(tensor)));
-            }
-
-            offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
-            if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
-                throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", ggml_get_name(tensor)));
-            }
-        }
-    };
-
-    // custom comparator to sort weights more nicely by layer
-    struct weight_name_comparer {
-        bool operator()(const std::string & a, const std::string & b) const {
-            int a_layer = -1;
-            int b_layer = -1;
-            sscanf(a.c_str(), "blk.%d.", &a_layer);
-            sscanf(b.c_str(), "blk.%d.", &b_layer);
-            if (a_layer != b_layer) {
-                return a_layer < b_layer;
-            }
-            return a < b;
-        }
-    };
-
-    std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
-    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
-
-    gguf_context_ptr meta;
-    std::vector<ggml_context_ptr> contexts;
-
-    std::string arch_name;
-    LLM_KV      llm_kv    = LLM_KV(LLM_ARCH_UNKNOWN);
-
-    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
-        int trace = 0;
-        if (getenv("LLAMA_TRACE")) {
-            trace = atoi(getenv("LLAMA_TRACE"));
-        }
-
-        if (param_overrides_p != nullptr) {
-            for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) {
-                kv_overrides.insert({std::string(p->key), *p});
-            }
-        }
-
-        struct ggml_context * ctx = NULL;
-        struct gguf_init_params params = {
-            /*.no_alloc = */ true,
-            /*.ctx      = */ &ctx,
-        };
-
-        meta.reset(gguf_init_from_file(fname.c_str(), params));
-        if (!meta) {
-            throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
-        }
-
-        get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
-        llm_kv = LLM_KV(llm_arch_from_string(arch_name));
-
-        files.emplace_back(new llama_file(fname.c_str(), "rb"));
-        contexts.emplace_back(ctx);
-
-        // Save tensors data offset of the main file.
-        // For subsidiary files, `meta` tensor data offset must not be used,
-        // so we build a unified tensors index for weights.
-        for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-            std::string tensor_name = std::string(cur->name);
-            // make sure there are no duplicated tensor names
-            if (weights_map.find(tensor_name) != weights_map.end()) {
-                throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
-            }
-            n_elements += ggml_nelements(cur);
-            n_bytes    += ggml_nbytes(cur);
-            weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur));
-        }
-        uint16_t n_split = 0;
-        get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false);
-
-        // Load additional GGML contexts
-        if (n_split > 1) {
-            uint16_t idx = 0;
-            get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
-            if (idx != 0) {
-                throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
-            }
-
-            char split_prefix[PATH_MAX] = {0};
-            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), fname.c_str(), idx, n_split)) {
-                throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
-            }
-
-            if (trace > 0) {
-                LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
-            }
-
-            char split_path[PATH_MAX] = {0};
-            for (idx = 1; idx < n_split; idx++) {
-                llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
-
-                struct gguf_init_params split_params = {
-                    /*.no_alloc = */ true,
-                    /*.ctx      = */ &ctx,
-                };
-                gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path, split_params) };
-                if (!ctx_gguf) {
-                    throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path));
-                }
-
-                files.emplace_back(new llama_file(split_path, "rb"));
-                contexts.emplace_back(ctx);
-
-                // Save tensors data offset info of the shard.
-                for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-                    std::string tensor_name = std::string(cur->name);
-                    // make sure there are no duplicated tensor names
-                    if (weights_map.find(tensor_name) != weights_map.end()) {
-                        throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
-                    }
-                    n_elements += ggml_nelements(cur);
-                    n_bytes    += ggml_nbytes(cur);
-                    weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur));
-                }
-            }
-
-            get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors);
-
-            // sanity check
-            {
-                const int n_tensors_loaded = (int) weights_map.size();
-                if (n_tensors != n_tensors_loaded) {
-                    throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded));
-                }
-            }
-
-            LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n",  __func__, n_split - 1);
-        }
-
-        n_kv      = gguf_get_n_kv(meta.get());
-        n_tensors = weights_map.size();
-
-        fver = (enum llama_fver) gguf_get_version(meta.get());
-
-        LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
-                __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver));
-
-        // determine file type based on the number of tensors for each quantization and print meta data
-        // TODO: make optional
-        {
-            std::map<enum ggml_type, uint32_t> n_type;
-
-            uint32_t n_type_max = 0;
-            enum ggml_type type_max = GGML_TYPE_F32;
-
-            for (const auto & it : weights_map) {
-                const llama_tensor_weight & w = it.second;
-                const ggml_tensor * tensor = w.tensor;
-
-                enum ggml_type type = tensor->type;
-
-                n_type[type]++;
-
-                if (n_type_max < n_type[type]) {
-                    n_type_max = n_type[type];
-                    type_max   = type;
-                }
-
-                if (trace > 0) {
-                    const uint16_t sid = w.idx;
-                    LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
-                }
-            }
-
-            switch (type_max) {
-                case GGML_TYPE_F32:     ftype = LLAMA_FTYPE_ALL_F32;        break;
-                case GGML_TYPE_F16:     ftype = LLAMA_FTYPE_MOSTLY_F16;     break;
-                case GGML_TYPE_BF16:    ftype = LLAMA_FTYPE_MOSTLY_BF16;    break;
-                case GGML_TYPE_Q4_0:    ftype = LLAMA_FTYPE_MOSTLY_Q4_0;    break;
-                case GGML_TYPE_Q4_1:    ftype = LLAMA_FTYPE_MOSTLY_Q4_1;    break;
-                case GGML_TYPE_Q5_0:    ftype = LLAMA_FTYPE_MOSTLY_Q5_0;    break;
-                case GGML_TYPE_Q5_1:    ftype = LLAMA_FTYPE_MOSTLY_Q5_1;    break;
-                case GGML_TYPE_Q8_0:    ftype = LLAMA_FTYPE_MOSTLY_Q8_0;    break;
-                case GGML_TYPE_Q2_K:    ftype = LLAMA_FTYPE_MOSTLY_Q2_K;    break;
-                case GGML_TYPE_Q3_K:    ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M;  break;
-                case GGML_TYPE_Q4_K:    ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M;  break;
-                case GGML_TYPE_Q5_K:    ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M;  break;
-                case GGML_TYPE_Q6_K:    ftype = LLAMA_FTYPE_MOSTLY_Q6_K;    break;
-                case GGML_TYPE_TQ1_0:   ftype = LLAMA_FTYPE_MOSTLY_TQ1_0;   break;
-                case GGML_TYPE_TQ2_0:   ftype = LLAMA_FTYPE_MOSTLY_TQ2_0;   break;
-                case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
-                case GGML_TYPE_IQ2_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS;  break;
-                case GGML_TYPE_IQ2_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ2_S;   break;
-                case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
-                case GGML_TYPE_IQ1_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_S;   break;
-                case GGML_TYPE_IQ1_M:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_M;   break;
-                case GGML_TYPE_IQ4_NL:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL;  break;
-                case GGML_TYPE_IQ4_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS;  break;
-                case GGML_TYPE_IQ3_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ3_S;   break;
-                default:
-                    {
-                        LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
-                        ftype = LLAMA_FTYPE_ALL_F32;
-                    } break;
-            }
-
-            // this is a way to mark that we have "guessed" the file type
-            ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
-
-            {
-                const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV
-                if (kid >= 0) {
-                    ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid);
-                }
-            }
-
-            LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
-
-            for (int i = 0; i < n_kv; i++) {
-                const char * name           = gguf_get_key(meta.get(), i);
-                const enum gguf_type type   = gguf_get_kv_type(meta.get(), i);
-                const std::string type_name =
-                    type == GGUF_TYPE_ARRAY
-                    ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i))
-                    : gguf_type_name(type);
-
-                std::string value          = gguf_kv_to_str(meta.get(), i);
-                const size_t MAX_VALUE_LEN = 40;
-                if (value.size() > MAX_VALUE_LEN) {
-                    value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
-                }
-                replace_all(value, "\n", "\\n");
-
-                LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
-            }
-
-            // print type counts
-            for (auto & kv : n_type) {
-                if (kv.second == 0) {
-                    continue;
-                }
-
-                LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second);
-            }
-        }
-
-        if (!llama_mmap::SUPPORTED) {
-            LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__);
-            use_mmap = false;
-        }
-
-        this->use_mmap = use_mmap;
-        this->check_tensors = check_tensors;
-    }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, bool>::type
-    get_arr_n(const std::string & key, T & result, const bool required = true) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
-
-        if (kid < 0) {
-            if (required) {
-                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
-
-
-        result = arr_info.length;
-        return true;
-    }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, bool>::type
-    get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_arr_n(llm_kv(kid), result, required);
-    }
-
-    template<typename T>
-    bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
-
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
-            if (required) {
-                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
-
-        switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value));  break;
-            default:
-                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
-        }
-
-        result.resize(arr_info.length);
-        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
-
-        return true;
-    }
-
-    template<typename T, size_t N_MAX>
-    bool get_arr(const std::string & key, std::array<T, N_MAX> & result, const bool required = true) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
-
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
-            if (required) {
-                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
-
-        switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value));  break;
-            default:
-                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
-        }
-
-        if (arr_info.length > N_MAX) {
-            throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
-        }
-
-        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
-
-        return true;
-    }
-
-    template<typename T>
-    bool get_arr(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_arr(llm_kv(kid), result, required);
-    }
-
-    template<typename T>
-    bool get_key(const std::string & key, T & result, const bool required = true) {
-        auto it = kv_overrides.find(key);
-
-        const struct llama_model_kv_override * override =
-            it != kv_overrides.end() ? &it->second : nullptr;
-
-        const bool found = GGUFMeta::GKV<T>::set(meta.get(), key, result, override);
-
-        if (required && !found) {
-            throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-        }
-
-        return found;
-    }
-
-    template<typename T>
-    bool get_key(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_key(llm_kv(kid), result, required);
-    }
-
-    // get array of n <= N_MAX elements, or a single element repeated n times
-    template<typename T, size_t N_MAX>
-    bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
-
-        if (kid < 0) {
-            if (required) {
-                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        if (n > N_MAX) {
-            throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
-        }
-
-        if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) {
-            struct GGUFMeta::ArrayInfo arr_info =
-                GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
-
-            if (n != arr_info.length) {
-                throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
-            }
-
-            return get_arr(key, result, required);
-        } else {
-            T value;
-
-            bool ok = get_key(key, value, required);
-            if (!ok) {
-                return false;
-            }
-
-            for (uint32_t i = 0; i < n; i++) {
-                result[i] = value;
-            }
-
-            return true;
-        }
-    }
-
-    template<typename T>
-    bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true) {
-        return get_key_or_arr(llm_kv(kid), result, n, required);
-    }
-
-    std::string get_arch_name() const {
-        return arch_name;
-    }
-
-    enum llm_arch get_arch() const {
-        return llm_kv.arch;
-    }
-
-    const llama_tensor_weight * get_weight(const char * name) const {
-        auto pos = weights_map.find(name);
-        if (pos != weights_map.end()) {
-            return &pos->second;
-        }
-
-        return nullptr;
-    }
-
-    const llama_tensor_weight & require_weight(const char * name) const {
-        const llama_tensor_weight * weight = get_weight(name);
-        if (!weight) {
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
-        }
-        return *weight;
-    }
-
-    struct ggml_tensor * get_tensor_meta(const char * name) const {
-        const auto * weight = get_weight(name);
-        if (!weight) {
-            return nullptr;
-        }
-        return weight->tensor;
-    }
-
-    struct ggml_tensor * require_tensor_meta(const std::string & name) const {
-        struct ggml_tensor * tensor = get_tensor_meta(name.c_str());
-        if (!tensor) {
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
-        }
-        return tensor;
-    }
-
-    const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const {
-        const struct ggml_tensor * cur = get_tensor_meta(name.c_str());
-
-        if (cur == NULL) {
-            if (!required) {
-                return NULL;
-            }
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
-        }
-
-        {
-            bool is_ok = true;
-            for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
-                if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) {
-                    is_ok = false;
-                    break;
-                }
-            }
-            if (!is_ok) {
-                throw std::runtime_error(
-                        format("%s: tensor '%s' has wrong shape; expected %s, got %s",
-                            __func__, name.c_str(),
-                            llama_format_tensor_shape(ne).c_str(),
-                            llama_format_tensor_shape(cur).c_str()));
-            }
-        }
-
-        return cur;
-    }
-
-    static const int TENSOR_NOT_REQUIRED = 1;
-    static const int TENSOR_DUPLICATED   = 2;
-
-    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags = 0) {
-        const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
-
-        if (cur == NULL) {
-            return NULL;
-        }
-
-        bool duplicated = flags & TENSOR_DUPLICATED;
-
-        struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
-        ggml_set_name(tensor, ggml_get_name(cur));
-
-        if (duplicated) {
-            size_data += ggml_nbytes(cur);
-        } else {
-            n_created++;
-        }
-
-        return tensor;
-
-    }
-
-    struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true) {
-        const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
-
-        if (cur == NULL) {
-            return NULL;
-        }
-
-        if (cur->type != base->type) {
-            throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type)));
-        }
-
-        std::array<int64_t, GGML_MAX_DIMS> dims;
-        for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
-            dims[i] = i < ne.size() ? ne.begin()[i] : 1;
-        }
-
-        struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
-                                        dims[0], dims[1], dims[2], dims[3],
-                                        cur->nb[1], cur->nb[2], cur->nb[3],
-                                        offset);
-
-        ggml_set_name(tensor, name.c_str());
-
-        n_created++;
-
-        return tensor;
-    }
-
-    void done_getting_tensors() const {
-        if (n_created != n_tensors) {
-            throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
-        }
-    }
-
-    void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr) {
-        if (use_mmap) {
-            mappings.reserve(files.size());
-            mmaps_used.reserve(files.size());
-            for (const auto & file : files) {
-                auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
-                auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
-                std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn()));
-                mmaps_used.emplace_back(mapping->size, 0);
-                if (mlock_mmaps) {
-                    std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
-                    mlock_mmap->init(mapping->addr);
-                    mlock_mmaps->emplace_back(std::move(mlock_mmap));
-                }
-                mappings.emplace_back(std::move(mapping));
-            }
-        }
-
-        // compute the total size of all tensors for progress reporting
-        for (const auto & it : weights_map) {
-            size_data += ggml_nbytes(it.second.tensor);
-        }
-    }
-
-    void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const {
-        GGML_ASSERT(!mappings.empty());
-        const auto & mapping = mappings.at(idx);
-
-        *first = mapping->size;
-        *last  = 0;
-        *addr = mapping->addr;
-        for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
-            const auto * weight = get_weight(ggml_get_name(tensor));
-            if (!weight || weight->idx != idx) {
-                continue;
-            }
-            *first = std::min(*first, weight->offs);
-            *last  = std::max(*last,  weight->offs + ggml_nbytes(tensor));
-        }
-    }
-
-    // for backwards compatibility, does not support ggml-backend
-    void load_data_for(struct ggml_tensor * cur) const {
-        const auto & w = require_weight(ggml_get_name(cur));
-
-        if (use_mmap) {
-            const auto & mapping = mappings.at(w.idx);
-            if (cur->data == nullptr) {
-                cur->data = (uint8_t *)mapping->addr + w.offs;
-            } else {
-                memcpy(cur->data, (uint8_t *)mapping->addr + w.offs, ggml_nbytes(cur));
-            }
-        } else {
-            GGML_ASSERT(cur->data != nullptr);
-            GGML_ASSERT(w.idx < files.size());
-            const auto & file = files.at(w.idx);
-            file->seek(w.offs, SEEK_SET);
-            file->read_raw(cur->data, ggml_nbytes(cur));
-        }
-
-        if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
-            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
-        }
-    }
-
-    size_t size_done = 0;
-    size_t size_data = 0;
-    std::vector<std::pair<size_t, size_t>> mmaps_used;
-
-    // Returns false if cancelled by progress_callback
-    bool load_all_data(
-            struct ggml_context * ctx,
-            llama_buf_map & bufs,
-            llama_mlocks * lmlocks,
-            llama_progress_callback progress_callback,
-            void * progress_callback_user_data) {
-        GGML_ASSERT(size_data != 0 && "call init_mappings() first");
-
-        std::vector<no_init<uint8_t>> read_buf;
-        std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
-
-        // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
-        // NVMe raid configurations might require more / larger buffers.
-        constexpr size_t n_buffers = 4;
-        constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
-
-        std::vector<ggml_backend_buffer_t> host_buffers;
-        std::vector<ggml_backend_event_t> events;
-        std::vector<void *> host_ptrs;
-        size_t buffer_idx = 0; // buffer to use for async loads
-        ggml_backend_t upload_backend = [&](const char * func) -> ggml_backend_t {
-            if (use_mmap || check_tensors) {
-                return nullptr;
-            }
-            // When not using mmaped io use async uploads from pinned memory to GPU memory.
-            // First determine if the backend supports the necessary features for async uploads.
-            auto * buf = bufs.count(0) ? bufs.at(0) : nullptr;
-            if (!buf) {
-                LLAMA_LOG_DEBUG("%s: no buffer found for async uploads\n", func);
-                return nullptr;
-            }
-
-            auto * buft = ggml_backend_buffer_get_type(buf);
-            auto * dev = ggml_backend_buft_get_device(buft);
-            if (!dev) {
-                LLAMA_LOG_DEBUG("%s: no device found for buffer type %s for async uploads\n", func,
-                    ggml_backend_buft_name(buft));
-                return nullptr;
-            }
-
-            if (buft != ggml_backend_dev_buffer_type(dev)) {
-                LLAMA_LOG_DEBUG("%s: buffer type %s is not the default buffer type for device %s for async uploads\n", func,
-                    ggml_backend_buft_name(buft), ggml_backend_dev_name(dev));
-                return nullptr;
-            }
-
-            ggml_backend_dev_props props;
-            ggml_backend_dev_get_props(dev, &props);
-            if (!props.caps.async || !props.caps.host_buffer || !props.caps.events) {
-                LLAMA_LOG_DEBUG("%s: device %s does not support async, host buffers or events\n", func,
-                    ggml_backend_dev_name(dev));
-                return nullptr;
-            }
-
-            auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
-            if (!host_buft) {
-                LLAMA_LOG_DEBUG("%s: no host buffer type found for device %s\n", func,
-                    ggml_backend_dev_name(dev));
-                return nullptr;
-            }
-
-            // If the backend is supported, create pinned memory buffers and events for synchronisation.
-            for (size_t idx = 0; idx < n_buffers; ++idx) {
-                auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size);
-                if (!buf) {
-                    LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func,
-                        ggml_backend_dev_name(dev));
-                    return nullptr;
-                }
-
-                host_buffers.emplace_back(buf);
-                host_ptrs.emplace_back(ggml_backend_buffer_get_base(buf));
-
-                auto * event = ggml_backend_event_new(dev);
-                if (!event) {
-                    LLAMA_LOG_DEBUG("%s: failed to create event for async uploads for device %s\n", func,
-                        ggml_backend_dev_name(dev));
-                    return nullptr;
-                }
-
-                events.emplace_back(event);
-            }
-
-            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
-            if (!backend) {
-                LLAMA_LOG_DEBUG("%s: failed to initialize backend for device %s for async uploads\n", func,
-                    ggml_backend_dev_name(dev));
-                return nullptr;
-            }
-
-            return backend;
-        }(__func__);
-
-        if (upload_backend) {
-            LLAMA_LOG_DEBUG("%s: using async uploads for device %s, buffer type %s, backend %s\n", __func__,
-                ggml_backend_dev_name(ggml_backend_get_device(upload_backend)),
-                ggml_backend_buft_name(ggml_backend_buffer_get_type(bufs.at(0))),
-                ggml_backend_name(upload_backend));
-        }
-
-        for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
-            const auto * weight = get_weight(ggml_get_name(cur));
-            if (weight == nullptr) {
-                // this can happen with split experts models
-                continue;
-            }
-
-            if (progress_callback) {
-                if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) {
-                    return false;
-                }
-            }
-
-            size_t n_size = ggml_nbytes(cur);
-
-            if (use_mmap) {
-                const auto & mapping = mappings.at(weight->idx);
-                ggml_backend_buffer_t buf_mmap = nullptr;
-                if (bufs.count(weight->idx)) {
-                    buf_mmap = bufs.at(weight->idx);
-                }
-                uint8_t * data = (uint8_t *) mapping->addr + weight->offs;
-
-                if (check_tensors) {
-                    validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] {
-                        return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size));
-                    }));
-                }
-
-                GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
-                if (buf_mmap && cur->data == nullptr) {
-                    ggml_backend_tensor_alloc(buf_mmap, cur, data);
-                    if (lmlocks) {
-                        const auto & lmlock = lmlocks->at(weight->idx);
-                        lmlock->grow_to(weight->offs + n_size);
-                    }
-
-                    auto & mmap_used = mmaps_used[weight->idx];
-                    mmap_used.first  = std::min(mmap_used.first,  weight->offs);
-                    mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
-                } else {
-                    ggml_backend_tensor_set(cur, data, 0, n_size);
-                }
-            } else {
-                const auto & file = files.at(weight->idx);
-                if (ggml_backend_buffer_is_host(cur->buffer)) {
-                    file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(cur->data, n_size);
-                    if (check_tensors) {
-                        validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
-                            return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
-                        }));
-                    }
-                } else {
-                    // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
-                    if (upload_backend) {
-                        file->seek(weight->offs, SEEK_SET);
-
-                        size_t bytes_read = 0;
-
-                        while (bytes_read < n_size) {
-                            size_t read_iteration = std::min(buffer_size, n_size - bytes_read);
-
-                            ggml_backend_event_synchronize(events[buffer_idx]);
-                            file->read_raw(host_ptrs[buffer_idx], read_iteration);
-                            ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
-                            ggml_backend_event_record(events[buffer_idx], upload_backend);
-
-                            bytes_read += read_iteration;
-                            ++buffer_idx;
-                            buffer_idx %= n_buffers;
-                        }
-                    } else {
-                        read_buf.resize(n_size);
-                        file->seek(weight->offs, SEEK_SET);
-                        file->read_raw(read_buf.data(), n_size);
-                        ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
-                        if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
-                            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
-                        }
-                    }
-                }
-            }
-
-            size_done += n_size;
-        }
-
-        // free temporary resources used for async uploads
-        for (auto * event : events) {
-            ggml_backend_event_synchronize(event);
-            ggml_backend_event_free(event);
-        }
-        for (auto * buf : host_buffers) {
-            ggml_backend_buffer_free(buf);
-        }
-        ggml_backend_free(upload_backend);
-
-        // check validation results
-        bool validation_failed = false;
-        for (auto & future : validation_result) {
-            auto result = future.get();
-            if (!result.second) {
-                LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first));
-                validation_failed = true;
-            }
-        }
-        if (validation_failed) {
-            throw std::runtime_error("found tensors with invalid data");
-        }
-
-        // check if this is the last call and do final cleanup
-        if (size_done >= size_data) {
-            // unmap offloaded tensors and metadata
-            if (use_mmap) {
-                for (uint32_t idx = 0; idx < mappings.size(); idx++) {
-                    const auto & mmap_used = mmaps_used.at(idx);
-                    auto & mapping = mappings.at(idx);
-                    mapping->unmap_fragment(0, mmap_used.first);
-                    if (mmap_used.second != 0) {
-                        mapping->unmap_fragment(mmap_used.second, mapping->size);
-                    }
-                }
-            }
-            if (progress_callback) {
-                // Even though the model is done loading, we still honor
-                // cancellation since we need to free allocations.
-                return progress_callback(1.0f, progress_callback_user_data);
-            }
-        }
-
-        return true;
-    }
-};
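For context on the staged-upload loop removed above: when the backend supports async copies, the loader rotates a small pool of pinned host buffers, waiting on each buffer's completion event before refilling it from the file and launching the next upload. Below is a minimal, self-contained sketch of that round-robin pattern; the buffer count, sizes, and names are illustrative only, and std::async plus memcpy stand in for ggml_backend_tensor_set_async and backend events, so this is not the removed implementation itself.

```cpp
// Minimal sketch of the round-robin staging-buffer upload pattern.
// std::async + memcpy are stand-ins for the backend's async tensor-set and events.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <future>
#include <vector>

int main() {
    constexpr size_t n_buffers   = 4;               // small pool of staging buffers
    constexpr size_t buffer_size = 1 * 1024 * 1024; // 1 MiB each

    std::vector<uint8_t> src(10 * buffer_size + 123, 0xAB); // plays the role of the file
    std::vector<uint8_t> dst(src.size());                   // plays the role of device memory

    std::vector<std::vector<uint8_t>> staging(n_buffers, std::vector<uint8_t>(buffer_size));
    std::vector<std::future<void>>    pending(n_buffers);   // per-buffer completion "events"

    size_t buffer_idx = 0;
    size_t bytes_done = 0;
    while (bytes_done < src.size()) {
        const size_t chunk = std::min(buffer_size, src.size() - bytes_done);

        // wait until the previous upload that used this staging buffer has finished
        if (pending[buffer_idx].valid()) {
            pending[buffer_idx].wait();
        }

        // "read" the next chunk from the file into the staging buffer
        std::memcpy(staging[buffer_idx].data(), src.data() + bytes_done, chunk);

        // kick off the asynchronous upload of this chunk to the destination
        uint8_t * out = dst.data() + bytes_done;
        uint8_t * in  = staging[buffer_idx].data();
        pending[buffer_idx] = std::async(std::launch::async, [out, in, chunk] {
            std::memcpy(out, in, chunk);
        });

        bytes_done += chunk;
        buffer_idx  = (buffer_idx + 1) % n_buffers;
    }

    // drain all outstanding uploads before using the destination
    for (auto & f : pending) {
        if (f.valid()) {
            f.wait();
        }
    }

    std::printf("copied %zu bytes, match = %d\n", bytes_done, src == dst);
    return 0;
}
```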
-
-// temporary allocate memory for the input batch if needed
-static const llama_seq_id batch_default_seq_id = 0;
-struct llama_batch_allocr {
-    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
-    struct llama_batch          batch;
-    // optionally fulfill the batch returned by llama_batch_get_one
-    llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
-        batch = in_batch;
-        GGML_ASSERT(batch.n_tokens > 0);
-        if (!batch.pos) {
-            // determine the last position in KV cache
-            llama_pos last_pos = -1;
-            for (const auto & cell : ctx.kv_self.cells) {
-                if (cell.has_seq_id(batch_default_seq_id)) {
-                    last_pos = std::max(last_pos, cell.pos);
-                }
-            }
-            last_pos++; // next position
-            pos.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                pos[i] = i+last_pos;
-            }
-            batch.pos = pos.data();
-        }
-        if (!batch.n_seq_id) {
-            n_seq_id.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                n_seq_id[i] = seq_id_0.size();
-            }
-            batch.n_seq_id = n_seq_id.data();
-        }
-        if (!batch.seq_id) {
-            seq_id.resize(batch.n_tokens + 1);
-            seq_id[batch.n_tokens] = NULL;
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                seq_id[i] = seq_id_0.data();
-            }
-            batch.seq_id = seq_id.data();
-        }
-        if (!batch.logits) {
-            logits.resize(batch.n_tokens);
-            logits[logits.size() - 1] = true;
-            batch.logits = logits.data();
-        }
-    }
-};
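As a usage note on the allocator removed above: llama_batch_get_one() only fills the token array and count, so the wrapper synthesizes the remaining fields with consecutive positions continuing after the last position already in the KV cache for sequence 0, one sequence id per token, and logits requested only for the final token. The following is a hedged, standalone illustration of that default-filling, using a simplified stand-in struct rather than the real llama_batch.

```cpp
// Illustration of the default-filling done by the batch allocator above.
// `toy_batch` is a simplified stand-in for llama_batch, not the real struct.
#include <cstdint>
#include <cstdio>
#include <vector>

struct toy_batch {
    int32_t              n_tokens = 0;
    std::vector<int32_t> pos;       // empty -> needs defaults
    std::vector<int32_t> n_seq_id;  // empty -> needs defaults
    std::vector<int8_t>  logits;    // empty -> needs defaults
};

// fill the optional fields the same way the removed allocator does
static void fill_defaults(toy_batch & b, int32_t last_kv_pos_seq0) {
    if (b.pos.empty()) {
        b.pos.resize(b.n_tokens);
        for (int32_t i = 0; i < b.n_tokens; ++i) {
            b.pos[i] = last_kv_pos_seq0 + 1 + i; // continue after the KV cache
        }
    }
    if (b.n_seq_id.empty()) {
        b.n_seq_id.assign(b.n_tokens, 1);        // every token belongs to seq 0 only
    }
    if (b.logits.empty()) {
        b.logits.assign(b.n_tokens, 0);
        b.logits.back() = 1;                     // logits only for the last token
    }
}

int main() {
    toy_batch b;
    b.n_tokens = 4;                  // e.g. a batch produced without pos/seq_id/logits
    fill_defaults(b, /*last_kv_pos_seq0=*/9);
    for (int32_t i = 0; i < b.n_tokens; ++i) {
        std::printf("token %d: pos=%d n_seq_id=%d logits=%d\n",
                    i, b.pos[i], b.n_seq_id[i], (int) b.logits[i]);
    }
    return 0;
}
```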
-
-template<>
-bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
-    uint32_t tmp;
-    const bool found = get_key(kid, tmp, required);
-    if (found) {
-        result = (enum llama_pooling_type) tmp;
-    } else {
-        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
-    }
-    return found;
-}
-
-
-//
-// load LLaMA models
-//
-
-static const char * llama_model_arch_name(llm_arch arch) {
-    auto it = LLM_ARCH_NAMES.find(arch);
-    if (it == LLM_ARCH_NAMES.end()) {
-        return "unknown";
-    }
-    return it->second;
-}
-
-static std::string llama_model_ftype_name(llama_ftype ftype) {
-    if (ftype & LLAMA_FTYPE_GUESSED) {
-        return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
-    }
-
-    switch (ftype) {
-        case LLAMA_FTYPE_ALL_F32:         return "all F32";
-        case LLAMA_FTYPE_MOSTLY_F16:      return "F16";
-        case LLAMA_FTYPE_MOSTLY_BF16:     return "BF16";
-        case LLAMA_FTYPE_MOSTLY_Q4_0:     return "Q4_0";
-        case LLAMA_FTYPE_MOSTLY_Q4_1:     return "Q4_1";
-        case LLAMA_FTYPE_MOSTLY_Q5_0:     return "Q5_0";
-        case LLAMA_FTYPE_MOSTLY_Q5_1:     return "Q5_1";
-        case LLAMA_FTYPE_MOSTLY_Q8_0:     return "Q8_0";
-        case LLAMA_FTYPE_MOSTLY_Q2_K:     return "Q2_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q2_K_S:   return "Q2_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q3_K_S:   return "Q3_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q3_K_M:   return "Q3_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q3_K_L:   return "Q3_K - Large";
-        case LLAMA_FTYPE_MOSTLY_Q4_K_S:   return "Q4_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q4_K_M:   return "Q4_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q5_K_S:   return "Q5_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q5_K_M:   return "Q5_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q6_K:     return "Q6_K";
-        case LLAMA_FTYPE_MOSTLY_TQ1_0:    return "TQ1_0 - 1.69 bpw ternary";
-        case LLAMA_FTYPE_MOSTLY_TQ2_0:    return "TQ2_0 - 2.06 bpw ternary";
-        case LLAMA_FTYPE_MOSTLY_IQ2_XXS:  return "IQ2_XXS - 2.0625 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ2_XS:   return "IQ2_XS - 2.3125 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ2_S:    return "IQ2_S - 2.5 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ2_M:    return "IQ2_M - 2.7 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_XS:   return "IQ3_XS - 3.3 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_XXS:  return "IQ3_XXS - 3.0625 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ1_S:    return "IQ1_S - 1.5625 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ1_M:    return "IQ1_M - 1.75 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ4_NL:   return "IQ4_NL - 4.5 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ4_XS:   return "IQ4_XS - 4.25 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_S:    return "IQ3_S - 3.4375 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_M:    return "IQ3_S mix - 3.66 bpw";
-
-        default: return "unknown, may not work";
-    }
-}
-
-static const char * llama_model_type_name(e_model type) {
-    switch (type) {
-        case MODEL_14M:           return "14M";
-        case MODEL_17M:           return "17M";
-        case MODEL_22M:           return "22M";
-        case MODEL_33M:           return "33M";
-        case MODEL_60M:           return "60M";
-        case MODEL_70M:           return "70M";
-        case MODEL_80M:           return "80M";
-        case MODEL_109M:          return "109M";
-        case MODEL_137M:          return "137M";
-        case MODEL_160M:          return "160M";
-        case MODEL_220M:          return "220M";
-        case MODEL_250M:          return "250M";
-        case MODEL_270M:          return "270M";
-        case MODEL_335M:          return "335M";
-        case MODEL_410M:          return "410M";
-        case MODEL_450M:          return "450M";
-        case MODEL_770M:          return "770M";
-        case MODEL_780M:          return "780M";
-        case MODEL_0_5B:          return "0.5B";
-        case MODEL_1B:            return "1B";
-        case MODEL_1_3B:          return "1.3B";
-        case MODEL_1_4B:          return "1.4B";
-        case MODEL_1_5B:          return "1.5B";
-        case MODEL_1_6B:          return "1.6B";
-        case MODEL_2B:            return "2B";
-        case MODEL_2_8B:          return "2.8B";
-        case MODEL_3B:            return "3B";
-        case MODEL_4B:            return "4B";
-        case MODEL_6B:            return "6B";
-        case MODEL_6_9B:          return "6.9B";
-        case MODEL_7B:            return "7B";
-        case MODEL_8B:            return "8B";
-        case MODEL_9B:            return "9B";
-        case MODEL_11B:           return "11B";
-        case MODEL_12B:           return "12B";
-        case MODEL_13B:           return "13B";
-        case MODEL_14B:           return "14B";
-        case MODEL_15B:           return "15B";
-        case MODEL_16B:           return "16B";
-        case MODEL_20B:           return "20B";
-        case MODEL_30B:           return "30B";
-        case MODEL_32B:           return "32B";
-        case MODEL_34B:           return "34B";
-        case MODEL_35B:           return "35B";
-        case MODEL_40B:           return "40B";
-        case MODEL_65B:           return "65B";
-        case MODEL_70B:           return "70B";
-        case MODEL_236B:          return "236B";
-        case MODEL_314B:          return "314B";
-        case MODEL_SMALL:         return "0.1B";
-        case MODEL_MEDIUM:        return "0.4B";
-        case MODEL_LARGE:         return "0.8B";
-        case MODEL_XL:            return "1.5B";
-        case MODEL_A1_7B:         return "A1.7B";
-        case MODEL_A2_7B:         return "A2.7B";
-        case MODEL_8x7B:          return "8x7B";
-        case MODEL_8x22B:         return "8x22B";
-        case MODEL_16x12B:        return "16x12B";
-        case MODEL_10B_128x3_66B: return "10B+128x3.66B";
-        case MODEL_57B_A14B:      return "57B.A14B";
-        case MODEL_27B:           return "27B";
-        default:                  return "?B";
-    }
-}
-
-static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
-    switch (type) {
-        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
-        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
-        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
-        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
-        default:                    return "unknown";
-    }
-}
-
-static void llm_load_stats(llama_model_loader & ml, llama_model & model) {
-    model.n_elements = ml.n_elements;
-    model.n_bytes = ml.n_bytes;
-}
-
-static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
-    model.arch = ml.get_arch();
-    if (model.arch == LLM_ARCH_UNKNOWN) {
-        throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'");
-    }
-}
-
-static void llm_load_hparams(
-        llama_model_loader & ml,
-        llama_model & model) {
-    auto & hparams = model.hparams;
-    const gguf_context * ctx = ml.meta.get();
-
-    // get metadata as string
-    for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
-        enum gguf_type type = gguf_get_kv_type(ctx, i);
-        if (type == GGUF_TYPE_ARRAY) {
-            continue;
-        }
-        const char * name = gguf_get_key(ctx, i);
-        const std::string value = gguf_kv_to_str(ctx, i);
-        model.gguf_kv.emplace(name, value);
-    }
-
-    // get general kv
-    ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
-
-    // get hparams kv
-    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
-
-    // everything past this point is not vocab-related
-    if (hparams.vocab_only) {
-        return;
-    }
-
-    ml.get_key(LLM_KV_CONTEXT_LENGTH,    hparams.n_ctx_train);
-    ml.get_key(LLM_KV_EMBEDDING_LENGTH,  hparams.n_embd);
-    ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
-    ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
-    ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
-
-    if (model.arch == LLM_ARCH_WAVTOKENIZER_DEC) {
-        ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
-
-        ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd);
-        ml.get_key(LLM_KV_POSNET_BLOCK_COUNT,      hparams.posnet.n_layer);
-
-        ml.get_key(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd);
-        ml.get_key(LLM_KV_CONVNEXT_BLOCK_COUNT,      hparams.convnext.n_layer);
-    }
-
-    GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
-    GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
-    if (hparams.n_expert > 0) {
-        GGML_ASSERT(hparams.n_expert_used > 0);
-    } else {
-        GGML_ASSERT(hparams.n_expert_used == 0);
-    }
-
-    // zero-out the array hparams
-    std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
-    std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
-    std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
-
-    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
-    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
-
-    // n_head_kv is optional, default to n_head
-    hparams.n_head_kv_arr = hparams.n_head_arr;
-
-    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false);
-
-    bool rope_finetuned = false;
-    ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
-    hparams.rope_finetuned = rope_finetuned;
-
-    hparams.n_ctx_orig_yarn = hparams.n_ctx_train;
-    ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false);
-
-    // rope_freq_base (optional)
-    hparams.rope_freq_base_train = 10000.0f;
-    ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, false);
-
-    std::string rope_scaling("linear");
-    ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false);
-    hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
-    GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED);
-
-    // rope_freq_scale (inverse of the kv) is optional
-    float ropescale = 0.0f;
-    if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) {
-        // try the old key name
-        ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false);
-    }
-    hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
-
-    ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
-
-    // non-transformer models do not have attention heads
-    if (hparams.n_head() > 0) {
-        // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
-        // gpt-j n_rot = rotary_dim
-
-        hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
-        ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
-
-        hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
-        ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
-
-        // sanity check for n_rot (optional)
-        hparams.n_rot = hparams.n_embd_head_k;
-
-        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-
-        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) {
-            if (hparams.n_rot != hparams.n_embd_head_k) {
-                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
-            }
-        }
-    } else {
-        hparams.n_rot = 0;
-        hparams.n_embd_head_k = 0;
-        hparams.n_embd_head_v = 0;
-    }
-
-    // arch-specific KVs
-    switch (model.arch) {
-        case LLM_ARCH_LLAMA:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                if (hparams.n_expert == 8) {
-                    switch (hparams.n_layer) {
-                        case 32: model.type = e_model::MODEL_8x7B; break;
-                        case 56: model.type = e_model::MODEL_8x22B; break;
-                        default: model.type = e_model::MODEL_UNKNOWN;
-                    }
-                } else {
-                    switch (hparams.n_layer) {
-                        case 16: model.type = e_model::MODEL_1B; break; // Llama 3.2 1B
-                        case 22: model.type = e_model::MODEL_1B; break;
-                        case 26: model.type = e_model::MODEL_3B; break;
-                        case 28: model.type = e_model::MODEL_3B; break; // Llama 3.2 3B
-                        // granite uses a vocab with len 49152
-                        case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break;
-                        case 36: model.type = e_model::MODEL_8B; break; // granite
-                        case 40: model.type = e_model::MODEL_13B; break;
-                        case 48: model.type = e_model::MODEL_34B; break;
-                        case 60: model.type = e_model::MODEL_30B; break;
-                        case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break;
-                        default: model.type = e_model::MODEL_UNKNOWN;
-                    }
-                }
-            } break;
-        case LLM_ARCH_DECI:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 80: model.type = e_model::MODEL_70B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_MINICPM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
-                ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
-                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
-
-                switch (hparams.n_layer) {
-                    case 52: model.type = e_model::MODEL_1B; break;
-                    case 40: model.type = e_model::MODEL_2B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_MINICPM3:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
-                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
-
-                switch (hparams.n_layer) {
-                    case 62: model.type = e_model::MODEL_4B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_GROK:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 64: model.type = e_model::MODEL_314B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_FALCON:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 60: model.type = e_model::MODEL_40B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_BAICHUAN:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-
-                if (model.type == e_model::MODEL_13B) {
-                    // TODO: become GGUF KV parameter
-                    hparams.f_max_alibi_bias = 8.0f;
-                }
-            } break;
-        case LLM_ARCH_STARCODER:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 36: model.type = e_model::MODEL_3B; break;
-                    case 42: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_15B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_REFACT:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_1B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-
-                // TODO: become GGUF KV parameter
-                hparams.f_max_alibi_bias = 8.0f;
-            } break;
-        case LLM_ARCH_BERT:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
-                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
-
-                switch (hparams.n_layer) {
-                    case 3:
-                        model.type = e_model::MODEL_17M; break; // bge-micro
-                    case 6:
-                        model.type = e_model::MODEL_22M; break; // MiniLM-L6
-                    case 12:
-                        switch (hparams.n_embd) {
-                            case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small
-                            case 768: model.type = e_model::MODEL_109M; break; // bge-base
-                        } break;
-                    case 24:
-                        model.type = e_model::MODEL_335M; break; // bge-large
-                }
-            } break;
-        case LLM_ARCH_JINA_BERT_V2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
-                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
-                hparams.f_max_alibi_bias = 8.0f;
-
-                switch (hparams.n_layer) {
-                    case 4:  model.type = e_model::MODEL_33M;  break; // jina-embeddings-small
-                    case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base
-                }
-            } break;
-        case LLM_ARCH_NOMIC_BERT:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
-                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type);
-
-                if (hparams.n_layer == 12 && hparams.n_embd == 768) {
-                    model.type = e_model::MODEL_137M;
-                }
-            } break;
-        case LLM_ARCH_BLOOM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 30:
-                        switch (hparams.n_embd) {
-                            case 2560: model.type = e_model::MODEL_3B; break;
-                            case 4096: model.type = e_model::MODEL_7B; break;
-                        } break;
-                }
-
-                // TODO: become GGUF KV parameter
-                hparams.f_max_alibi_bias = 8.0f;
-            } break;
-        case LLM_ARCH_MPT:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,      hparams.f_clamp_kqv, false);
-                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 48: model.type = e_model::MODEL_30B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_STABLELM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_3B; break;
-                    case 40: model.type = e_model::MODEL_12B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_QWEN:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_QWEN2VL:
-            {
-                std::array<int, 4> section_dims;
-                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, section_dims, 4, true);
-                std::copy(section_dims.begin(), section_dims.begin() + 4, std::begin(hparams.rope_sections));
-            }
-            // fall through
-        case LLM_ARCH_QWEN2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break;
-                    case 28: model.type = hparams.n_embd == 1536 ? e_model::MODEL_1_5B : e_model::MODEL_7B; break;
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 36: model.type = e_model::MODEL_3B; break;
-                    case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break;
-                    case 48: model.type = e_model::MODEL_14B; break;
-                    case 64: model.type = e_model::MODEL_32B; break;
-                    case 80: model.type = e_model::MODEL_70B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_QWEN2MOE:
-            {
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
-                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
-
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_A2_7B; break;
-                    case 28: model.type = e_model::MODEL_57B_A14B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_PHI2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_3B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_PHI3:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_3B; break;
-                    case 40: model.type = e_model::MODEL_14B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-
-                // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
-                if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
-                    // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
-                    hparams.n_swa = 2047;
-                } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-mini-128k-instruct
-                    hparams.n_swa = 262144;
-                } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-medium-128k-instruct
-                    hparams.n_swa = 131072;
-                }
-                bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
-                if (!found_swa && hparams.n_swa == 0) {
-                    throw std::runtime_error("invalid value for sliding_window");
-                }
-            } break;
-        case LLM_ARCH_PLAMO:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_GPT2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 12: model.type = e_model::MODEL_SMALL; break;
-                    case 24: model.type = e_model::MODEL_MEDIUM; break;
-                    case 36: model.type = e_model::MODEL_LARGE; break;
-                    case 48: model.type = e_model::MODEL_XL; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_CODESHELL:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 42: model.type = e_model::MODEL_7B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_ORION:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 40: model.type = e_model::MODEL_14B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_INTERNLM2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 48: model.type = e_model::MODEL_20B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_GEMMA:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 18: model.type = e_model::MODEL_2B; break;
-                    case 28: model.type = e_model::MODEL_7B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_GEMMA2:
-            {
-                hparams.n_swa = 4096; // default value of gemma 2
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
-                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
-                hparams.attn_soft_cap = true;
-
-                switch (hparams.n_layer) {
-                    case 26: model.type = e_model::MODEL_2B; break;
-                    case 42: model.type = e_model::MODEL_9B; break;
-                    case 46: model.type = e_model::MODEL_27B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_STARCODER2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 30: model.type = e_model::MODEL_3B; break;
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_15B; break;
-                    case 52: model.type = e_model::MODEL_20B; break; // granite
-                    case 88: model.type = e_model::MODEL_34B; break; // granite
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_MAMBA:
-            {
-                ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
-                ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
-                ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
-                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
-                ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false);
-
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 24:
-                        switch (hparams.n_embd) {
-                            case 768: model.type = e_model::MODEL_SMALL; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 48:
-                        switch (hparams.n_embd) {
-                            case 1024: model.type = e_model::MODEL_MEDIUM; break;
-                            case 1536: model.type = e_model::MODEL_LARGE; break;
-                            case 2048: model.type = e_model::MODEL_XL; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 64:
-                        switch (hparams.n_embd) {
-                            case 2560: model.type = e_model::MODEL_3B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_XVERSE:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    case 80: model.type = e_model::MODEL_65B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_COMMAND_R:
-            {
-                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 40: model.type = e_model::MODEL_35B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_DBRX:
-        {
-            ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
-            ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,      hparams.f_clamp_kqv);
-
-            switch (hparams.n_layer) {
-                case 40: model.type = e_model::MODEL_16x12B; break;
-                default: model.type = e_model::MODEL_UNKNOWN;
-            }
-        } break;
-        case LLM_ARCH_OLMO:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,     hparams.f_clamp_kqv, false);
-
-                switch (hparams.n_layer) {
-                    case 22: model.type = e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 80: model.type = e_model::MODEL_70B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_OLMO2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 16: model.type = e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_OLMOE:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 16: model.type = e_model::MODEL_A1_7B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_OPENELM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                case 16: model.type = e_model::MODEL_270M; break;
-                case 20: model.type = e_model::MODEL_450M; break;
-                case 28: model.type = e_model::MODEL_1B; break;
-                case 36: model.type = e_model::MODEL_3B; break;
-                default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_GPTNEOX:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
-                switch (hparams.n_layer) {
-                    case 6:
-                        switch (hparams.n_ff()) {
-                            case 512: model.type = e_model::MODEL_14M; break;
-                            case 2048: model.type = e_model::MODEL_70M; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 12:
-                        switch (hparams.n_ff()) {
-                            case 3072: model.type = e_model::MODEL_160M; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 16:
-                        switch (hparams.n_ff()) {
-                            case 8192: model.type = e_model::MODEL_1B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 24:
-                        switch (hparams.n_ff()) {
-                            case 4096: model.type = e_model::MODEL_410M; break;
-                            case 8192: model.type = e_model::MODEL_1_4B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 32:
-                        switch (hparams.n_ff()) {
-                            case 10240: model.type = e_model::MODEL_2_8B; break;
-                            case 16384: model.type = e_model::MODEL_6_9B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 36:
-                        switch (hparams.n_ff()) {
-                            case 20480: model.type = e_model::MODEL_12B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 44:
-                        switch (hparams.n_ff()) {
-                            case 24576: model.type = e_model::MODEL_20B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_ARCTIC:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                if (hparams.n_expert == 128) {
-                    switch (hparams.n_layer) {
-                        case 35: model.type = e_model::MODEL_10B_128x3_66B; break;
-                        default: model.type = e_model::MODEL_UNKNOWN;
-                    }
-                } else {
-                    model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_DEEPSEEK:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
-                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
-
-                switch (hparams.n_layer) {
-                    case 28: model.type = e_model::MODEL_20B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_DEEPSEEK2:
-            {
-                bool is_lite = (hparams.n_layer == 27);
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
-                if (!is_lite) {
-                    ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
-                }
-                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
-                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
-
-                switch (hparams.n_layer) {
-                    case 27: model.type = e_model::MODEL_16B; break;
-                    case 60: model.type = e_model::MODEL_236B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_CHATGLM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 28: model.type = e_model::MODEL_6B; break;
-                    case 40: model.type = e_model::MODEL_9B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_BITNET:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 26: model.type = e_model::MODEL_3B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_T5:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
-
-                uint32_t dec_start_token_id;
-                if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) {
-                    hparams.dec_start_token_id = dec_start_token_id;
-                }
-
-                switch (hparams.n_layer) {
-                    case 6:  model.type = e_model::MODEL_60M;  break; // t5-small
-                    case 8:  model.type = e_model::MODEL_80M;  break; // flan-t5-small
-                    case 12:
-                        switch (hparams.n_ff()) {
-                            case 3072: model.type = e_model::MODEL_220M; break; // t5-base
-                            case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 24:
-                        switch (hparams.n_ff()) {
-                            case 4096:  model.type = e_model::MODEL_770M; break; // t5-large
-                            case 2816:  model.type = e_model::MODEL_780M; break; // flan-t5-large
-                            case 16384: model.type = e_model::MODEL_3B;   break; // t5-3b
-                            case 5120:  model.type = e_model::MODEL_3B;   break; // flan-t5-xl
-                            case 65536: model.type = e_model::MODEL_11B;  break; // t5-11b
-                            case 10240: model.type = e_model::MODEL_11B;  break; // flan-t5-xxl
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_T5ENCODER:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
-                model.type = e_model::MODEL_UNKNOWN;
-            } break;
-        case LLM_ARCH_JAIS:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1_3B; break;
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    /* TODO: add variants */
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_NEMOTRON:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_4B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_EXAONE:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_8B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_RWKV6:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
-                ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
-                ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
-                ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1_6B; break;
-                    case 32:
-                        switch (hparams.n_embd) {
-                            case 2560: model.type = e_model::MODEL_3B; break;
-                            case 4096: model.type = e_model::MODEL_7B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 61: model.type = e_model::MODEL_14B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
-                ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
-                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
-                ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_3B; break;
-                    case 40: model.type = e_model::MODEL_3B; break;
-                    // Add additional layer/vocab/etc checks here for other model sizes
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_CHAMELEON:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                hparams.f_norm_eps = 1e-5;  // eps for qk-norm, torch default
-                ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 48: model.type = e_model::MODEL_34B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_WAVTOKENIZER_DEC:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS,    hparams.f_norm_group_eps);
-                ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
-                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
-            } break;
-        default: (void)0;
-    }
-
-    model.ftype = ml.ftype;
-
-    if (hparams.f_max_alibi_bias > 0.0f) {
-        hparams.use_alibi = true;
-    }
-
-    hparams.rope_type = llama_rope_type(&model);
-}
-
-static void llm_load_vocab(
-        llama_model_loader & ml,
-        llama_model & model) {
-    auto & vocab = model.vocab;
-
-    struct gguf_context * ctx = ml.meta.get();
-
-    const auto kv = LLM_KV(model.arch);
-
-    // determine vocab type
-    {
-        std::string tokenizer_model;
-        std::string tokenizer_pre;
-
-        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
-        ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
-
-        if (tokenizer_model == "no_vocab" || tokenizer_model == "none") {
-            vocab.type = LLAMA_VOCAB_TYPE_NONE;
-
-            // default special tokens
-            vocab.special_bos_id  = LLAMA_TOKEN_NULL;
-            vocab.special_eos_id  = LLAMA_TOKEN_NULL;
-            vocab.special_unk_id  = LLAMA_TOKEN_NULL;
-            vocab.special_sep_id  = LLAMA_TOKEN_NULL;
-            vocab.special_pad_id  = LLAMA_TOKEN_NULL;
-            vocab.special_cls_id  = LLAMA_TOKEN_NULL;
-            vocab.special_mask_id = LLAMA_TOKEN_NULL;
-            vocab.linefeed_id     = LLAMA_TOKEN_NULL;
-
-            // read vocab size from metadata
-            if (!ml.get_key(LLM_KV_VOCAB_SIZE, vocab.n_vocab, false)) {
-                vocab.n_vocab = 0;
-                LLAMA_LOG_WARN("%s: there is no vocab_size in metadata, vocab.n_vocab will be set to %u\n", __func__, vocab.n_vocab);
-            }
-            return;
-        }
-
-        if (tokenizer_model == "llama") {
-            vocab.type = LLAMA_VOCAB_TYPE_SPM;
-
-            // default special tokens
-            vocab.special_bos_id  = 1;
-            vocab.special_eos_id  = 2;
-            vocab.special_unk_id  = 0;
-            vocab.special_sep_id  = LLAMA_TOKEN_NULL;
-            vocab.special_pad_id  = LLAMA_TOKEN_NULL;
-            vocab.special_cls_id  = LLAMA_TOKEN_NULL;
-            vocab.special_mask_id = LLAMA_TOKEN_NULL;
-        } else if (tokenizer_model == "bert") {
-            vocab.type = LLAMA_VOCAB_TYPE_WPM;
-
-            // default special tokens
-            vocab.special_bos_id  = LLAMA_TOKEN_NULL;
-            vocab.special_eos_id  = LLAMA_TOKEN_NULL;
-            vocab.special_unk_id  = 100;
-            vocab.special_sep_id  = 102;
-            vocab.special_pad_id  = 0;
-            vocab.special_cls_id  = 101;
-            vocab.special_mask_id = 103;
-        } else if (tokenizer_model == "gpt2") {
-            vocab.type = LLAMA_VOCAB_TYPE_BPE;
-
-            // read bpe merges and populate bpe ranks
-            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
-            if (merges_keyidx == -1) {
-                throw std::runtime_error("cannot find tokenizer merges in model file\n");
-            }
-
-            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
-            for (int i = 0; i < n_merges; i++) {
-                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
-                GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
-
-                std::string first;
-                std::string second;
-
-                const size_t pos = word.find(' ', 1);
-
-                if (pos != std::string::npos) {
-                    first  = word.substr(0, pos);
-                    second = word.substr(pos + 1);
-                }
-
-                vocab.bpe_ranks.emplace(std::make_pair(first, second), i);
-            }
-
-            // default special tokens
-            vocab.special_bos_id  = 11;
-            vocab.special_eos_id  = 11;
-            vocab.special_unk_id  = LLAMA_TOKEN_NULL;
-            vocab.special_sep_id  = LLAMA_TOKEN_NULL;
-            vocab.special_pad_id  = LLAMA_TOKEN_NULL;
-            vocab.special_cls_id  = LLAMA_TOKEN_NULL;
-            vocab.special_mask_id = LLAMA_TOKEN_NULL;
-        } else if (tokenizer_model == "t5") {
-            vocab.type = LLAMA_VOCAB_TYPE_UGM;
-
-            // default special tokens
-            vocab.special_bos_id  = LLAMA_TOKEN_NULL;
-            vocab.special_eos_id  = 1;
-            vocab.special_unk_id  = 2;
-            vocab.special_sep_id  = LLAMA_TOKEN_NULL;
-            vocab.special_pad_id  = 0;
-            vocab.special_cls_id  = LLAMA_TOKEN_NULL;
-            vocab.special_mask_id = LLAMA_TOKEN_NULL;
-
-            const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
-            if (precompiled_charsmap_keyidx != -1) {
-                size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
-                const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
-                vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap);
-#ifdef IS_BIG_ENDIAN
-                // correct endianness of data in precompiled_charsmap binary blob
-                uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0];
-                *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
-                assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap);
-                size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t);
-                uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)];
-                for (size_t i = 0; i < xcda_array_size; ++i) {
-                    xcda_array[i] = __builtin_bswap32(xcda_array[i]);
-                }
-#endif
-            }
-        } else if (tokenizer_model == "rwkv") {
-            vocab.type = LLAMA_VOCAB_TYPE_RWKV;
-
-            // default special tokens
-            vocab.special_bos_id = LLAMA_TOKEN_NULL;
-            vocab.special_eos_id = LLAMA_TOKEN_NULL;
-            vocab.special_unk_id = LLAMA_TOKEN_NULL;
-            vocab.special_sep_id = LLAMA_TOKEN_NULL;
-            vocab.special_pad_id = LLAMA_TOKEN_NULL;
-        } else {
-            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
-        }
-
-        // for now, only BPE models have pre-tokenizers
-        if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = true;
-            if (tokenizer_pre.empty()) {
-                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
-                LLAMA_LOG_WARN("%s:                                             \n", __func__);
-                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
-                LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED!        \n", __func__);
-                LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL             \n", __func__);
-                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
-                LLAMA_LOG_WARN("%s:                                             \n", __func__);
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            } else if (tokenizer_pre == "default") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            } else if (
-                    tokenizer_pre == "llama3"   ||
-                    tokenizer_pre == "llama-v3" ||
-                    tokenizer_pre == "llama-bpe"||
-                    tokenizer_pre == "falcon3") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
-                vocab.tokenizer_ignore_merges = true;
-                vocab.tokenizer_add_bos = true;
-            } else if (
-                    tokenizer_pre == "deepseek-llm") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "deepseek-coder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "falcon") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
-            } else if (
-                    tokenizer_pre == "mpt") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT;
-            } else if (
-                    tokenizer_pre == "starcoder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER;
-            } else if (
-                    tokenizer_pre == "gpt-2"   ||
-                    tokenizer_pre == "phi-2"   ||
-                    tokenizer_pre == "jina-es" ||
-                    tokenizer_pre == "jina-de" ||
-                    tokenizer_pre == "gigachat"   ||
-                    tokenizer_pre == "jina-v1-en" ||
-                    tokenizer_pre == "jina-v2-es" ||
-                    tokenizer_pre == "jina-v2-de" ||
-                    tokenizer_pre == "jina-v2-code" ||
-                    tokenizer_pre == "roberta-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2;
-            } else if (
-                    tokenizer_pre == "refact") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT;
-            } else if (
-                tokenizer_pre == "command-r") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "qwen2") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "stablelm2") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
-            } else if (
-                tokenizer_pre == "olmo") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO;
-            } else if (
-                tokenizer_pre == "dbrx") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
-            } else if (
-                tokenizer_pre == "smaug-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
-            } else if (
-                tokenizer_pre == "poro-chat") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "chatglm-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
-                vocab.special_bos_id = LLAMA_TOKEN_NULL;
-            } else if (
-                tokenizer_pre == "viking") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "jais") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
-            } else if (
-                tokenizer_pre == "tekken") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
-                vocab.tokenizer_clean_spaces = false;
-                vocab.tokenizer_ignore_merges = true;
-                vocab.tokenizer_add_bos = true;
-            } else if (
-                tokenizer_pre == "smollm") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "codeshell") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
-            } else if (
-                tokenizer_pre == "bloom") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BLOOM;
-            } else if (
-                tokenizer_pre == "gpt3-finnish") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH;
-            } else if (
-                tokenizer_pre == "exaone") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE;
-            } else if (
-                tokenizer_pre == "chameleon") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
-                vocab.tokenizer_add_bos = true;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "minerva-7b") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MINERVA;
-            } else if (
-                tokenizer_pre == "megrez") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
-            } else {
-                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
-            }
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = true;
-            vocab.tokenizer_clean_spaces = false;
-            vocab.tokenizer_add_bos = true;
-            vocab.tokenizer_add_eos = false;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = true;
-            vocab.tokenizer_add_bos = true;
-            vocab.tokenizer_add_eos = false;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_bos = false;
-            vocab.tokenizer_add_eos = true;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = false;
-            vocab.tokenizer_add_bos = false;
-            vocab.tokenizer_add_eos = false;
-        } else {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-        }
-
-        ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX,      vocab.tokenizer_add_space_prefix,         false);
-        ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false);
-    }
-
-    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
-    if (token_idx == -1) {
-        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
-    }
-
-    const float * scores = nullptr;
-    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
-    if (score_idx != -1) {
-        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
-    }
-
-    const int * toktypes = nullptr;
-    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
-    if (toktype_idx != -1) {
-        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
-    }
-
-    const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
-
-    vocab.n_vocab = n_vocab;
-    vocab.id_to_token.resize(n_vocab);
-
-    for (uint32_t i = 0; i < n_vocab; i++) {
-        std::string word = gguf_get_arr_str(ctx, token_idx, i);
-
-        //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
-        if (word.empty()) {
-            LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i);
-            word = "[EMPTY_" + std::to_string(i) + "]";
-        }
-
-        vocab.token_to_id[word] = i;
-        vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
-
-        auto & token_data = vocab.id_to_token[i];
-        token_data.text  = std::move(word);
-        token_data.score = scores ? scores[i] : 0.0f;
-        token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
-
-        if (toktypes) {  //TODO: remove, required until per token attributes are available from GGUF file
-            switch(toktypes[i]) {
-                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
-                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
-                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
-                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
-                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
-                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
-                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
-                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
-            }
-        }
-    }
-    GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
-
-    vocab.init_tokenizer();
-
-    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
-    if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-        try {
-            vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
-        } catch (const std::exception & e) {
-            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
-            vocab.linefeed_id = vocab.special_pad_id;
-        }
-    } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
-        vocab.linefeed_id = vocab.special_pad_id;
-    } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
-        const std::vector ids = llama_tokenize_internal(vocab, "\n", false);
-        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
-        vocab.linefeed_id = ids[0];
-    } else {
-        const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
-
-        //GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
-        if (ids.empty()) {
-            LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
-            vocab.linefeed_id = vocab.special_pad_id;
-        } else {
-            vocab.linefeed_id = ids[0];
-        }
-    }
-
-    // special tokens
-    {
-        const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID,     vocab.special_bos_id     },
-            { LLM_KV_TOKENIZER_EOS_ID,     vocab.special_eos_id     },
-            { LLM_KV_TOKENIZER_EOT_ID,     vocab.special_eot_id     },
-            { LLM_KV_TOKENIZER_EOM_ID,     vocab.special_eom_id     },
-            { LLM_KV_TOKENIZER_UNK_ID,     vocab.special_unk_id     },
-            { LLM_KV_TOKENIZER_SEP_ID,     vocab.special_sep_id     },
-            { LLM_KV_TOKENIZER_PAD_ID,     vocab.special_pad_id     },
-            { LLM_KV_TOKENIZER_CLS_ID,     vocab.special_cls_id     },
-            { LLM_KV_TOKENIZER_MASK_ID,    vocab.special_mask_id    },
-            { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id },
-            { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id },
-            { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id },
-            { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id },
-            { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id },
-            { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id },
-
-            // deprecated
-            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_fim_pre_id },
-            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_fim_suf_id },
-            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_fim_mid_id },
-        };
-
-        for (const auto & it : special_token_types) {
-            const std::string & key = kv(std::get<0>(it));
-            int32_t & id = std::get<1>(it);
-
-            uint32_t new_id;
-            if (!ml.get_key(std::get<0>(it), new_id, false)) {
-                continue;
-            }
-            if (new_id >= vocab.id_to_token.size()) {
-                LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
-                    __func__, key.c_str(), new_id, id);
-            } else {
-                id = new_id;
-            }
-        }
-
-        // Handle add_bos_token and add_eos_token
-        {
-            bool temp = true;
-
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
-                vocab.tokenizer_add_bos = temp;
-            }
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
-                vocab.tokenizer_add_eos = temp;
-            }
-        }
-
-        // auto-detect special tokens by text
-        // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
-        //       for now, we apply this workaround to find the tokens based on their text
-
-        for (const auto & t : vocab.token_to_id) {
-            // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
-            if (vocab.special_eot_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|eot_id|>"
-                        || t.first == "<|im_end|>"
-                        || t.first == "<|end|>"
-                        || t.first == ""
-                        || t.first == "<|endoftext|>"
-                        || t.first == ""
-                        || t.first == "<|end▁of▁sentence|>" // DeepSeek
-                   ) {
-                    vocab.special_eot_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find EOM token: "<|eom_id|>"
-            if (vocab.special_eom_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|eom_id|>"
-                        ) {
-                    vocab.special_eom_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
-            if (vocab.special_fim_pre_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_prefix|>"  // Qwen
-                        || t.first == "<fim-prefix>"
-                        || t.first == "<|fim▁begin|>" // DeepSeek
-                        || t.first == "<PRE>"
-                        ) {
-                    vocab.special_fim_pre_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
-            if (vocab.special_fim_suf_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_suffix|>" // Qwen
-                        || t.first == ""
-                        || t.first == "<|fim▁hole|>" // DeepSeek
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_suf_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
-            if (vocab.special_fim_mid_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_middle|>" // Qwen
-                        || t.first == ""
-                        || t.first == "<|fim▁end|>"  // DeepSeek
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_mid_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
-            if (vocab.special_fim_pad_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_pad|>" // Qwen
-                        || t.first == ""
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_pad_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
-            if (vocab.special_fim_rep_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_repo|>"  // Qwen
-                        || t.first == "<|repo_name|>"
-                        || t.first == ""
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_rep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_SEP token: "<|file_sep|>"
-            if (vocab.special_fim_sep_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|file_sep|>" // Qwen
-                        ) {
-                    vocab.special_fim_sep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-        }
-
-        // maintain a list of tokens that cause end-of-generation
-        // this is currently determined based on the token text, which is obviously not ideal
-        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
-        vocab.special_eog_ids.clear();
-
-        if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
-        }
-
-        if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
-        }
-
-        if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
-        }
-
-        for (const auto & t : vocab.token_to_id) {
-            if (false
-                    || t.first == "<|eot_id|>"
-                    || t.first == "<|im_end|>"
-                    || t.first == "<|end|>"
-                    || t.first == ""
-                    || t.first == "<|endoftext|>"
-                    || t.first == "<|eom_id|>"
-                    || t.first == ""
-               ) {
-                vocab.special_eog_ids.insert(t.second);
-                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.second, t.first.c_str());
-                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                }
-            } else {
-                // token is control, but not marked as EOG -> print a debug log
-                if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
-                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
-                            __func__, t.second, t.first.c_str());
-                }
-            }
-        }
-
-        // sanity checks
-        if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eos_id);
-            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
-
-        if (vocab.special_eot_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eot_id);
-            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
-
-        if (vocab.special_eom_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eom_id);
-            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
-    }
-
-    // build special tokens cache
-    {
-        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
-                vocab.cache_special_tokens.push_back(id);
-            }
-        }
-
-        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
-            [&] (const llama_vocab::id a, const llama_vocab::id b) {
-                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
-            }
-        );
-
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
-    }
-
-    // build token to piece cache
-    {
-        size_t size_cache = 0;
-
-        std::vector<std::string> cache_token_to_piece(n_vocab);
-
-        for (uint32_t id = 0; id < n_vocab; ++id) {
-            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
-
-            size_cache += cache_token_to_piece[id].size();
-        }
-
-        std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
-
-        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
-    }
-
-    // Handle per token attributes
-    //NOTE: Each model customizes per token attributes.
-    //NOTE: Per token attributes are missing from the GGUF file.
-    //TODO: Extract attributes from GGUF file.
-    {
-        auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
-            for (auto substr : substrs) {
-                if (str.find(substr) < std::string::npos) {
-                    return true;
-                }
-            }
-            return false;
-        };
-
-        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
-            uint32_t current = vocab.id_to_token.at(id).attr;
-            current = value ? (current | attr) : (current & ~attr);
-            vocab.id_to_token[id].attr = (llama_token_attr) current;
-        };
-
-        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
-            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
-        };
-
-        std::string model_name;
-        std::string tokenizer_pre;
-
-        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
-        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
-
-        // model name to lowercase
-        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
-            [] (const std::string::value_type x) {
-                return std::tolower(x);
-            }
-        );
-
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
-        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
-            for (auto id : vocab.cache_special_tokens) {
-                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {""}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {"", "", "<|endoftext|>"}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
-            }
-        }
-    }
-}
-
-static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
-    const auto & hparams = model.hparams;
-    const auto & vocab   = model.vocab;
-
-    const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
-
-    auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
-        bool is_var = false;
-
-        std::vector<uint32_t> v;
-        for (uint32_t i = 0; i < n; ++i) {
-            v.push_back(f(i));
-            if (v[i] != v[0]) {
-                is_var = true;
-            }
-        }
-
-        std::stringstream ss;
-
-        if (is_var) {
-            ss << "[";
-            for (uint32_t i = 0; i < n; ++i) {
-                ss << v[i];
-                if (i < n - 1) {
-                    ss << ", ";
-                }
-            }
-            ss << "]";
-        } else {
-            ss << v[0];
-        }
-
-        return ss.str();
-    };
-
-    // hparams
-    LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
-    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch));
-    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
-    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
-    LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
-
-    if (!hparams.vocab_only) {
-        LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
-        LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
-        LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
-        LLAMA_LOG_INFO("%s: n_head           = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head(il);    }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
-        LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
-        LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
-        LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
-        LLAMA_LOG_INFO("%s: n_gqa            = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il);        }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
-        LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
-        LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
-        LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
-        LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
-        LLAMA_LOG_INFO("%s: n_ff             = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
-        LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
-        LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
-        LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
-        LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
-        LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
-        LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
-        LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
-        LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
-        LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
-        LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
-        LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
-        LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
-        LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
-        LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
-    }
-
-    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
-    LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
-    if (ml.n_elements >= 1e12) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, ml.n_elements*1e-12);
-    } else if (ml.n_elements >= 1e9) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
-    } else if (ml.n_elements >= 1e6) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f M\n", __func__, ml.n_elements*1e-6);
-    } else {
-        LLAMA_LOG_INFO("%s: model params     = %.2f K\n", __func__, ml.n_elements*1e-3);
-    }
-    if (ml.n_bytes < GiB) {
-        LLAMA_LOG_INFO("%s: model size       = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0,        ml.n_bytes*8.0/ml.n_elements);
-    } else {
-        LLAMA_LOG_INFO("%s: model size       = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
-    }
-
-    // general kv
-    LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
-
-    // special tokens
-    if (vocab.special_bos_id  != -1)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,     vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id  != -1)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,     vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_eot_id  != -1)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,     vocab.id_to_token[vocab.special_eot_id].text.c_str() );  }
-    if (vocab.special_eom_id  != -1)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,     vocab.id_to_token[vocab.special_eom_id].text.c_str() );  }
-    if (vocab.special_unk_id  != -1)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,     vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id  != -1)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,     vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id  != -1)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,     vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id  != -1)    { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,     vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id != -1)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id,    vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-
-    if (vocab.linefeed_id != -1)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,        vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
-
-    if (vocab.special_fim_pre_id != -1) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token[vocab.special_fim_pre_id].text.c_str() ); }
-    if (vocab.special_fim_suf_id != -1) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token[vocab.special_fim_suf_id].text.c_str() ); }
-    if (vocab.special_fim_mid_id != -1) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token[vocab.special_fim_mid_id].text.c_str() ); }
-    if (vocab.special_fim_pad_id != -1) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token[vocab.special_fim_pad_id].text.c_str() ); }
-    if (vocab.special_fim_rep_id != -1) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token[vocab.special_fim_rep_id].text.c_str() ); }
-    if (vocab.special_fim_sep_id != -1) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token[vocab.special_fim_sep_id].text.c_str() ); }
-
-    for (const auto & id : vocab.special_eog_ids) {
-        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
-    }
-
-    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
-
-    if (model.arch == LLM_ARCH_DEEPSEEK) {
-        LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
-        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
-        LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
-        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
-    }
-
-    if (model.arch == LLM_ARCH_DEEPSEEK2) {
-        LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
-        LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
-        LLAMA_LOG_INFO("%s: n_lora_kv            = %d\n",     __func__, hparams.n_lora_kv);
-        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
-        LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
-        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
-        LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
-    }
-
-    if (model.arch == LLM_ARCH_QWEN2MOE) {
-        LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
-        LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
-    }
-
-    if (model.arch == LLM_ARCH_MINICPM || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
-        LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
-        LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
-        LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
-    }
-}
-
-enum llm_tensor_layer {
-    LLM_TENSOR_LAYER_INPUT,
-    LLM_TENSOR_LAYER_REPEATING,
-    LLM_TENSOR_LAYER_OUTPUT,
-};
-
-struct llm_tensor_info {
-    llm_tensor_layer layer;
-    ggml_op op;
-};
-
-static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
-    {LLM_TENSOR_TOKEN_EMBD,                 {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_POS_EMBD,                   {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_TOKEN_EMBD_NORM,            {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_TOKEN_TYPES,                {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_OUTPUT,                     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_CLS,                        {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_CLS_OUT,                    {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_OUTPUT_NORM,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
-    {LLM_TENSOR_DEC_OUTPUT_NORM,            {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
-    {LLM_TENSOR_ENC_OUTPUT_NORM,            {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
-    {LLM_TENSOR_ROPE_FREQS,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
-    {LLM_TENSOR_ROPE_FACTORS_LONG,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
-    {LLM_TENSOR_ROPE_FACTORS_SHORT,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
-    {LLM_TENSOR_ATTN_Q,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_QKV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_OUT,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN_SHEXP,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_SHEXP,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP_SHEXP,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_A,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_B,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_A_MQA,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_B,                  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_Q,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_K,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_V,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_OUT,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_CROSS_ATTN_Q,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_CROSS_ATTN_K,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_CROSS_ATTN_V,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_CROSS_ATTN_OUT,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_FFN_GATE,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_FFN_DOWN,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_FFN_UP,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ENC_ATTN_Q,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ENC_ATTN_K,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ENC_ATTN_V,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ENC_ATTN_OUT,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ENC_FFN_GATE,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ENC_FFN_DOWN,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ENC_FFN_UP,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_INP_SHEXP,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_INP,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_SSM_IN,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_SSM_X,                      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_SSM_DT,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_SSM_OUT,                    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_W1,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_W2,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_DECAY_W1,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_DECAY_W2,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_KEY,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_VALUE,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_RECEPTANCE,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_GATE,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_OUTPUT,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_CHANNEL_MIX_KEY,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_CHANNEL_MIX_VALUE,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_ACT,                    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
-    {LLM_TENSOR_SSM_CONV1D,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
-    {LLM_TENSOR_SSM_A,                      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
-    {LLM_TENSOR_SSM_D,                      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_TIME_MIX_LERP_X,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_TIME_MIX_LN,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_CHANNEL_MIX_LERP_K,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_CHANNEL_MIX_LERP_R,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_TIME_MIX_LERP_W,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_LERP_K,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_LERP_V,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_LERP_R,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_LERP_G,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_DECAY,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_FIRST,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
-    {LLM_TENSOR_ATTN_NORM,                  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ATTN_NORM_2,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ATTN_OUT_NORM,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ATTN_POST_NORM,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_FFN_NORM,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_FFN_POST_NORM,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_FFN_NORM_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ATTN_Q_NORM,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ATTN_K_NORM,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_LAYER_OUT_NORM,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ATTN_Q_A_NORM,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ATTN_KV_A_NORM,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ATTN_SUB_NORM,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_FFN_SUB_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_DEC_ATTN_NORM,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_DEC_CROSS_ATTN_NORM,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_DEC_FFN_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ENC_ATTN_NORM,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ENC_FFN_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_DEC_ATTN_REL_B,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_ENC_ATTN_REL_B,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_FFN_DOWN_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
-    {LLM_TENSOR_FFN_GATE_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
-    {LLM_TENSOR_FFN_UP_EXPS,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
-    // this tensor is loaded for T5, but never used
-    {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
-    {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
-    {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_POS_NET_NORM2,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_POS_NET_CONV1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
-    {LLM_TENSOR_POS_NET_CONV2,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
-    {LLM_TENSOR_POS_NET_ATTN_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_POS_NET_ATTN_Q,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_POS_NET_ATTN_K,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_POS_NET_ATTN_V,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_POS_NET_ATTN_OUT,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_CONVNEXT_DW,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
-    {LLM_TENSOR_CONVNEXT_NORM,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_CONVNEXT_PW1,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_CONVNEXT_PW2,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_CONVNEXT_GAMMA,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-};
-
 // checks if the weight tensor can be used with the specified buffer type and device
 static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
     GGML_ASSERT(w != nullptr);
@@ -7841,11 +414,12 @@ static bool llm_load_tensors(
                 tn_tensor = LLM_TENSOR_OUTPUT;
             }
 
-            auto it = llm_tensor_info_mapping.find(tn_tensor);
-            if (it == llm_tensor_info_mapping.end()) {
+            llm_tensor_info info;
+            try {
+                info = llm_tensor_info_for(tn_tensor);
+            } catch (const std::out_of_range & e) {
                 throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
             }
-            const auto & info = it->second;
 
             // tensors with "bias" suffix are always used with GGML_OP_ADD
             ggml_op op;
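
A minimal, self-contained sketch of the assumed lookup pattern behind llm_tensor_info_for() (its implementation is not shown in this diff; the catch above implies an .at()-style lookup that throws std::out_of_range on a miss). The names info_for, infos and tensor_info below are hypothetical stand-ins, not llama.cpp API:

#include <cstdio>
#include <map>
#include <stdexcept>

// hypothetical stand-ins used only to illustrate the lookup-then-rethrow pattern
struct tensor_info { int layer_kind; int op; };

static const std::map<int, tensor_info> infos = { { 0, { 0, 1 } }, { 1, { 1, 2 } } };

static const tensor_info & info_for(int tensor) {
    return infos.at(tensor); // std::map::at throws std::out_of_range for unknown keys
}

int main() {
    try {
        (void) info_for(42);
    } catch (const std::out_of_range &) {
        std::printf("missing tensor info mapping for tensor 42\n");
    }
    return 0;
}
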
@@ -8638,6 +1212,50 @@ static bool llm_load_tensors(
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
 
+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                    }
+                } break;
+            case LLM_ARCH_PHIMOE:
+                {
+                    const int64_t n_embd_head = n_embd / n_head;
+
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), { n_embd, n_vocab }, 0);
+                    model.output_b      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "bias"),   { n_vocab }, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias",   i), { n_embd }, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        if (layer.wqkv == nullptr) {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias",   i), {n_embd}, 0);
+
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa}, 0);
+
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa}, 0);
+                        }
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), { n_embd }, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias",   i), { n_embd }, 0);
+
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert},         0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
+
                         layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                         layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                     }
@@ -8970,6 +1588,32 @@ static bool llm_load_tensors(
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_COHERE2:
+                {
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    // init output from the input tok embed
+                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
+                                                      llama_model_loader::TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
+                    }
+                }
+                break;
             case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
                 {
                     model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -9249,6 +1893,7 @@ static bool llm_load_tensors(
                             layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                         } else {
                             layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                             if (n_expert == 0) {
                                 throw std::runtime_error("n_expert must be > 0");
@@ -9931,21 +2576,36 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
+
+        // apply lora for embedding tokens if needed
+        for (auto & it : lctx.lora_adapters) {
+            struct llama_lora_weight * lora = it.first->get_weight(tok_embd);
+            if (lora == nullptr) {
+                continue;
+            }
+            const float adapter_scale = it.second;
+            const float scale = lora->get_scale(it.first->alpha, adapter_scale);
+            struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
+                ctx, lora->b, // non-transposed lora_b
+                ggml_get_rows(ctx, lora->a, lctx.inp_tokens)
+            ), scale);
+            inpL = ggml_add(ctx, inpL, inpL_delta);
+        }
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
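
A rough, self-contained sketch (plain arrays, toy sizes, hypothetical names) of what the new embedding-LoRA branch computes: each token's row of lora->a is gathered and projected back up through lora->b, scaled by the usual adapter_scale * alpha / rank factor:

#include <cstdio>
#include <vector>

// toy illustration, not llama.cpp code: vocab of 3 tokens, embedding dim 4, LoRA rank 2
int main() {
    const int n_vocab = 3, n_embd = 4, rank = 2;
    const float alpha = 16.0f, adapter_scale = 1.0f;
    const float scale = adapter_scale * alpha / rank;

    std::vector<float> embd(n_vocab * n_embd, 0.1f); // base token embeddings
    std::vector<float> A(n_vocab * rank, 0.01f);     // low-rank rows, one per token
    std::vector<float> B(rank * n_embd, 0.02f);      // up-projection back to n_embd

    const int token = 1;
    for (int j = 0; j < n_embd; ++j) {
        float delta = 0.0f;
        for (int r = 0; r < rank; ++r) {
            delta += A[token * rank + r] * B[r * n_embd + j];
        }
        std::printf("dim %d: %.4f\n", j, embd[token * n_embd + j] + scale * delta);
    }
    return 0;
}
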
@@ -10016,9 +2676,8 @@ static struct ggml_tensor * llm_build_lora_mm(
         if (lora == nullptr) {
             continue;
         }
-        const float alpha = it.first->alpha;
-        const float rank  = (float) lora->b->ne[0];
-        const float scale = alpha ? it.second * alpha / rank : it.second;
+        const float adapter_scale = it.second;
+        const float scale = lora->get_scale(it.first->alpha, adapter_scale);
         struct ggml_tensor * ab_cur = ggml_mul_mat(
             ctx0, lora->b,
             ggml_mul_mat(ctx0, lora->a, cur)
@@ -10229,12 +2888,14 @@ static struct ggml_tensor * llm_build_moe_ffn(
          struct ggml_tensor * up_exps,
          struct ggml_tensor * gate_exps,
          struct ggml_tensor * down_exps,
+         struct ggml_tensor * exp_probs_b,
                     int64_t   n_expert,
                     int64_t   n_expert_used,
             llm_ffn_op_type   type_op,
                        bool   norm_w,
                        bool   scale_w,
                       float   w_scale,
+llama_expert_gating_func_type gating_op,
          const llm_build_cb & cb,
                         int   il) {
     int64_t n_embd = cur->ne[0];
@@ -10243,11 +2904,31 @@ static struct ggml_tensor * llm_build_moe_ffn(
     ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
-    ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+    ggml_tensor * probs = nullptr;
+    switch (gating_op) {
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
+            {
+                probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+            } break;
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
+            {
+                probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
     cb(probs, "ffn_moe_probs", il);
 
+    // add experts selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get expert weights
+    ggml_tensor * selection_probs = probs;
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx, probs, exp_probs_b);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
     // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+    ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
     cb(selected_experts, "ffn_moe_topk", il);
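
A toy, self-contained illustration of the gating change above, with made-up names and values: probabilities come from softmax or sigmoid, the optional per-expert bias (exp_probs_b) shifts only the scores used for top-k selection, and the unbiased probabilities remain the expert weights:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

// 4 experts, pick the top 2; the bias changes which experts win, not their weights
int main() {
    const std::vector<float> logits = { 0.2f, 1.0f, 0.9f, -0.5f };
    const std::vector<float> bias   = { 0.0f, 0.0f, 0.0f,  2.0f }; // exp_probs_b analogue

    std::vector<float> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = 1.0f / (1.0f + std::exp(-logits[i])); // sigmoid gating
    }

    std::vector<float> selection(probs.size());
    for (size_t i = 0; i < probs.size(); ++i) {
        selection[i] = probs[i] + bias[i]; // bias affects selection only
    }

    std::vector<size_t> idx(probs.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + 2, idx.end(),
        [&](size_t a, size_t b) { return selection[a] > selection[b]; });

    for (int k = 0; k < 2; ++k) {
        std::printf("expert %zu weight %.3f\n", idx[k], probs[idx[k]]); // unbiased weight
    }
    return 0;
}
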
 
@@ -10518,7 +3199,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
        struct llama_context & lctx,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
@@ -10534,17 +3215,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs  = batch.n_seqs;
+    const int64_t n_seqs  = ubatch.n_seqs;
     // Some variants of the Mamba arch (e.g. FalconMamba) do apply layer norm on B and Dt layers
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all  = kv.v_l[il];
@@ -11344,6 +4025,7 @@ struct llm_build_context {
 
             // feed-forward network
             if (model.layers[il].ffn_gate_inp == nullptr) {
+
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
                         model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -11368,9 +4050,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
             }
@@ -12020,9 +4704,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_GELU, true,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -12161,9 +4847,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -13409,9 +6097,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, false,
                         false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                         cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -13620,7 +6310,7 @@ struct llm_build_context {
 
                 struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm,
-                    NULL,
+                    model.layers[il].attn_norm_b,
                     LLM_NORM_RMS, cb, il);
                 cb(attn_norm_output, "attn_norm", il);
 
@@ -13635,8 +6325,7 @@ struct llm_build_context {
                     Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
                     Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
-                }
-                else {
+                } else {
                     Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
@@ -13680,14 +6369,12 @@ struct llm_build_context {
             residual = cur;
 
             cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_norm, NULL,
+                model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
                 LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            // FF
-            // special-case: the up and gate tensors are merged into a single tensor
-            // TOOD: support into llm_build_ffn
-            {
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
                 cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         NULL,                      NULL, NULL,
@@ -13695,6 +6382,20 @@ struct llm_build_context {
                         NULL,
                         LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        cb, il);
+                cb(cur, "ffn_moe_out", il);
             }
 
             cur = ggml_add(ctx0, residual, cur);
@@ -13707,11 +6408,16 @@ struct llm_build_context {
 
         cur = llm_build_norm(ctx0, inpL, hparams,
             model.output_norm,
-            NULL,
+            model.output_norm_b,
             LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (model.output_b != nullptr) {
+            cb(cur, "result_output_no_bias", -1);
+            cur = ggml_add(ctx0, cur, model.output_b);
+        }
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -14644,9 +7350,9 @@ struct llm_build_context {
 
                 // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
                 switch (model.type) {
-                    case e_model::MODEL_2B:
-                    case e_model::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
-                    case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    case llm_type::MODEL_2B:
+                    case llm_type::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case llm_type::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
                     default: GGML_ABORT("fatal error");
                 };
                 cb(Qcur, "Qcur_scaled", il);
@@ -15051,6 +7757,137 @@ struct llm_build_context {
 
     }
 
+    struct ggml_cgraph * build_cohere2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        const float f_logit_scale = hparams.f_logit_scale;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        // cohere2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
+
+        // sliding window switch pattern
+        const int32_t sliding_window_pattern = 4;
+
+        for (int il = 0; il < n_layer; ++il) {
+            // three layers use sliding window attention (window size 4096) with RoPE,
+            // and every fourth layer uses global attention without positional embeddings
+            const bool           is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+            struct ggml_tensor * ffn_inp = cur;
+
+            // self-attention
+            {
+                // rope freq factors for 128k context
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                if (is_sliding) {
+                    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+                                        beta_fast, beta_slow);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                                        rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                                        attn_factor, beta_fast, beta_slow);
+                    cb(Kcur, "Kcur", il);
+                } else {
+                    // For non-sliding layers, just reshape without applying RoPE
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
+                                   KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur                              = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL                             = ggml_get_rows(ctx0, inpL, inp_out_ids);
+                ffn_inp                          = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
+            struct ggml_tensor * attn_out = cur;
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
+                                    NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
+                                    cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // add together residual + FFN + self-attention
+            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, attn_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
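
A standalone sketch of the layer pattern encoded by sliding_window_pattern = 4 above: within each group of four layers, the first three use sliding-window attention with RoPE and the fourth uses global attention without positional embeddings:

#include <cstdio>

// minimal check of the is_sliding condition used in build_cohere2()
int main() {
    const int sliding_window_pattern = 4;
    for (int il = 0; il < 8; ++il) {
        const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
        std::printf("layer %d: %s\n", il, is_sliding ? "sliding window + RoPE" : "global, no RoPE");
    }
    return 0; // prints sliding for layers 0-2 and 4-6, global for layers 3 and 7
}
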
     // ref: https://allenai.org/olmo
     // based on the original build_llama() function, changes:
     //   * non-parametric layer norm
@@ -15403,9 +8240,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, false,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -15800,9 +8639,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -15941,9 +8782,11 @@ struct llm_build_context {
                             model.layers[il].ffn_up_exps,
                             model.layers[il].ffn_gate_exps,
                             model.layers[il].ffn_down_exps,
+                            nullptr,
                             n_expert, n_expert_used,
                             LLM_FFN_SILU, false,
                             false, hparams.expert_weights_scale,
+                            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                             cb, il);
                 cb(moe_out, "ffn_moe_out", il);
 
@@ -16170,9 +9013,11 @@ struct llm_build_context {
                             model.layers[il].ffn_up_exps,
                             model.layers[il].ffn_gate_exps,
                             model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b,
                             n_expert, n_expert_used,
-                            LLM_FFN_SILU, false,
+                            LLM_FFN_SILU, hparams.expert_weights_norm,
                             true, hparams.expert_weights_scale,
+                            (enum llama_expert_gating_func_type) hparams.expert_gating_func,
                             cb, il);
                 cb(moe_out, "ffn_moe_out", il);
 
@@ -17751,6 +10596,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_phi2();
             } break;
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
             {
                 result = llm.build_phi3();
             } break;
@@ -17802,6 +10648,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_command_r();
             } break;
+        case LLM_ARCH_COHERE2:
+            {
+                result = llm.build_cohere2();
+            } break;
         case LLM_ARCH_DBRX:
             {
                 result = llm.build_dbrx();
@@ -17896,572 +10746,6 @@ static struct ggml_cgraph * llama_build_graph(
     return result;
 }
 
-static void llama_set_k_shift(llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].delta;
-    }
-}
-
-static void llama_set_s_copy(llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].src;
-    }
-}
-
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min(relative_position, 0);
-    }
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-    return relative_bucket;
-}
-
-static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
-    //
-    // set input data
-    //
-
-    const auto & hparams = lctx.model.hparams;
-    const auto & cparams = lctx.cparams;
-    const auto & kv_self = lctx.kv_self;
-
-    if (ubatch.token) {
-        const int64_t n_tokens = ubatch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
-    }
-
-    if (ubatch.embd) {
-        const int64_t n_embd   = hparams.n_embd;
-        const int64_t n_tokens = ubatch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
-    }
-
-    if (ubatch.pos && lctx.inp_pos) {
-        const int64_t n_tokens = ubatch.n_tokens;
-        auto n_pos = lctx.n_pos_per_token;
-        ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
-    }
-
-    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
-        //GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
-
-        if (!lctx.inp_out_ids) {
-            LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__);
-        } else {
-            const int64_t n_tokens = ubatch.n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
-            int32_t * data = (int32_t *) lctx.inp_out_ids->data;
-
-            if (lctx.n_outputs == n_tokens) {
-                for (int i = 0; i < n_tokens; ++i) {
-                    data[i] = i;
-                }
-            } else if (ubatch.output) {
-                int32_t n_outputs = 0;
-                for (int i = 0; i < n_tokens; ++i) {
-                    if (ubatch.output[i]) {
-                        data[n_outputs++] = i;
-                    }
-                }
-                // the graph needs to have been passed the correct number of outputs
-                GGML_ASSERT(lctx.n_outputs == n_outputs);
-            } else if (lctx.n_outputs == 1) {
-                // only keep last output
-                data[0] = n_tokens - 1;
-            } else {
-                GGML_ASSERT(lctx.n_outputs == 0);
-            }
-        }
-    }
-
-    GGML_ASSERT(
-        // (!a || b) is a logical implication (a -> b)
-        // !hparams.causal_attn -> !cparams.causal_attn
-        (hparams.causal_attn || !cparams.causal_attn) &&
-        "causal attention is not supported by this model"
-    );
-
-    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn && !lctx.is_encoding) {
-            const int64_t n_kv         = kv_self.n;
-            const int64_t n_tokens     = ubatch.n_tokens;
-            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-            const int64_t n_seqs       = ubatch.n_seqs;
-
-
-            float * data     = nullptr;
-            float * data_swa = nullptr;
-
-            if (lctx.inp_KQ_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-                data = (float *) lctx.inp_KQ_mask->data;
-            }
-
-            if (lctx.inp_KQ_mask_swa) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
-                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
-            }
-
-            // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the ubatch.
-            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
-
-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
-                                f = -INFINITY;
-                            } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self.cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
-                            }
-
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
-
-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
-                                if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
-                        }
-                    }
-                }
-
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
-                    }
-                }
-
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
-                    }
-                }
-            }
-        } else {
-            const int64_t n_tokens     = ubatch.n_tokens;
-            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-            const int64_t n_seqs       = ubatch.n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-
-            float * data = (float *) lctx.inp_KQ_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch.seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
-                                    if (ubatch.seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-        const int64_t n_tokens     = ubatch.n_tokens;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_seqs       = ubatch.n_seqs;
-
-        GGML_ASSERT(lctx.inp_mean);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
-
-        float * data = (float *) lctx.inp_mean->data;
-        memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
-
-        std::vector<uint64_t> sum(n_tokens, 0);
-
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
-            sum[seq_id] += ubatch.n_seq_tokens;
-        }
-
-        std::vector<float> div(n_tokens, 0.0f);
-        for (int i = 0; i < n_tokens; ++i) {
-            const uint64_t s = sum[i];
-            if (s > 0) {
-                div[i] = 1.0f/float(s);
-            }
-        }
-
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
-            }
-        }
-    }
-
-    if (cparams.embeddings && (
-                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
-        const int64_t n_tokens     = ubatch.n_tokens;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_seqs       = ubatch.n_seqs;
-
-        GGML_ASSERT(lctx.inp_cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
-
-        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
-
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
-
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
-
-                if (pos == 0) {
-                    data[seq_id] = s*n_seq_tokens + i;
-                }
-            }
-        }
-    }
-
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens     = ubatch.n_tokens;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_seqs       = ubatch.n_seqs;
-
-        GGML_ASSERT(lctx.inp_cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
-
-        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
-
-        std::vector<int> last_pos(n_tokens, -1);
-        std::vector<int> last_row(n_tokens, -1);
-
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
-
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
-
-                if (pos >= last_pos[seq_id]) {
-                    last_pos[seq_id] = pos;
-                    last_row[seq_id] = s*n_seq_tokens + i;
-                }
-            }
-        }
-
-        for (int i = 0; i < n_tokens; ++i) {
-            if (last_row[i] >= 0) {
-                data[i] = last_row[i];
-            }
-        }
-    }
-
-    if (kv_self.recurrent) {
-        const int64_t n_kv = kv_self.n;
-
-        if (lctx.inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
-            float * data = (float *) lctx.inp_s_mask->data;
-
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-
-        if (lctx.inp_s_copy) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-            int32_t * data = (int32_t *) lctx.inp_s_copy->data;
-
-            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
-
-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
-                    kv_cell.src = cell_id;
-                }
-
-                data[i] = kv_cell.src;
-
-                // ensure copy only happens once
-                if (kv_cell.src != (int32_t) cell_id) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-    }
-
-    if (lctx.inp_pos_bucket) {
-        const int64_t n_tokens = ubatch.n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
-        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
-
-        int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
-
-        if (!lctx.is_encoding) {
-            const int64_t n_kv = kv_self.n;
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
-                    }
-                }
-            }
-        } else {
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
-                    }
-                }
-            }
-        }
-    }
-
-    if (!lctx.is_encoding && lctx.inp_embd_enc) {
-        assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
-        assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
-
-        ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
-    }
-
-    if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
-        const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
-        const int64_t n_tokens = ubatch.n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
-        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
-
-        float * data = (float *) lctx.inp_KQ_mask_cross->data;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_output_enc; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = ubatch.seq_id[j][s];
-                        if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
-                            f = 0.0f;
-                        }
-                    }
-                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
-                }
-            }
-
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_output_enc; ++j) {
-                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
-                }
-            }
-        }
-    }
-}
-
-// Make sure enough space is available for outputs.
-// Returns max number of outputs for which space was reserved.
-static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
-    const auto & cparams = lctx.cparams;
-    const auto & hparams = lctx.model.hparams;
-
-    const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
-
-    const auto n_batch = cparams.n_batch;
-    const auto n_vocab = hparams.n_vocab;
-    const auto n_embd  = hparams.n_embd;
-
-    // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
-    const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
-
-    const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
-
-    if (lctx.output_ids.empty()) {
-        // init, never resized afterwards
-        lctx.output_ids.resize(n_batch);
-    }
-
-    const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
-    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
-
-    // alloc only when more than the current capacity is required
-    // TODO: also consider shrinking the buffer
-    if (!lctx.buf_output || prev_size < new_size) {
-        if (lctx.buf_output) {
-#ifndef NDEBUG
-            // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
-            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
-#endif
-            lctx.buf_output = nullptr;
-            lctx.logits = nullptr;
-            lctx.embd = nullptr;
-        }
-
-        auto * buft = ggml_backend_cpu_buffer_type();
-        // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
-        auto * output_dev = lctx.model.dev_output.dev;
-        auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
-        if (output_dev_host_buft) {
-            buft = output_dev_host_buft;
-        }
-        lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
-        if (lctx.buf_output == nullptr) {
-            LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
-            return 0;
-        }
-    }
-
-    float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get());
-
-    lctx.logits = has_logits ? output_base               : nullptr;
-    lctx.embd   = has_embd   ? output_base + logits_size : nullptr;
-
-    lctx.output_size = n_outputs_max;
-    lctx.logits_size = logits_size;
-    lctx.embd_size   = embd_size;
-
-    // set all ids as invalid (negative)
-    std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
-
-    ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
-
-    lctx.n_outputs = 0;
-
-    return n_outputs_max;
-}
-
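A minimal sketch, assuming simplified names, of the layout llama_output_reserve() sets up above: logits and embeddings live in one contiguous allocation, with the embedding view starting logits_size floats into the block.

#include <cstddef>
#include <vector>

struct output_buffer {
    std::vector<float> data;   // backs both views, like lctx.buf_output
    float * logits = nullptr;  // n_vocab floats per output row
    float * embd   = nullptr;  // n_embd  floats per output row
};

static output_buffer reserve_outputs(std::size_t n_outputs, std::size_t n_vocab, std::size_t n_embd,
                                     bool want_logits, bool want_embd) {
    output_buffer buf;
    const std::size_t logits_size = want_logits ? n_vocab * n_outputs : 0;
    const std::size_t embd_size   = want_embd   ? n_embd  * n_outputs : 0;
    buf.data.assign(logits_size + embd_size, 0.0f);               // one contiguous, zeroed block
    buf.logits = want_logits ? buf.data.data()               : nullptr;
    buf.embd   = want_embd   ? buf.data.data() + logits_size : nullptr;
    return buf;
}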
-// make the outputs have the same order they had in the user-provided batch
-static void llama_output_reorder(struct llama_context * ctx) {
-    std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
-    if (!out_ids.empty()) {
-        uint32_t n_vocab = ctx->model.hparams.n_vocab;
-        uint32_t n_embd  = ctx->model.hparams.n_embd;
-        int32_t n_outputs = ctx->n_outputs;
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (ctx->logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
-                }
-            }
-            if (ctx->embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            ctx->output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
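The same selection-sort reorder, reduced to a toy helper so the invariant is easier to see: every swap applied to out_ids is mirrored on the corresponding logits rows. Names and types are illustrative.

#include <algorithm>
#include <cstddef>
#include <vector>

static void reorder_rows(std::vector<std::size_t> & out_ids, std::vector<float> & logits, std::size_t n_vocab) {
    const std::size_t n = out_ids.size();
    for (std::size_t i = 0; i + 1 < n; ++i) {
        std::size_t j_min = i;
        for (std::size_t j = i + 1; j < n; ++j) {
            if (out_ids[j] < out_ids[j_min]) { j_min = j; }
        }
        if (j_min == i) { continue; }
        std::swap(out_ids[i], out_ids[j_min]);
        // keep row k of logits aligned with out_ids[k]
        std::swap_ranges(logits.begin() +  i   *n_vocab,
                         logits.begin() + (i+1)*n_vocab,
                         logits.begin() + j_min*n_vocab);
    }
}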
 // returns the result of ggml_backend_sched_graph_compute_async execution
 static enum ggml_status llama_graph_compute(
           llama_context & lctx,
@@ -18501,7 +10785,7 @@ static enum ggml_status llama_graph_compute(
 // return positive int on warning
 // return negative int on error
 //
-static int llama_decode_internal(
+static int llama_decode_impl(
          llama_context & lctx,
            llama_batch   inp_batch) {
 
@@ -18513,7 +10797,8 @@ static int llama_decode_internal(
     }
 
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+
     const llama_batch & batch = batch_allocr.batch;
     const uint32_t n_tokens_all = batch.n_tokens;
 
@@ -18835,7 +11120,7 @@ static int llama_decode_internal(
 // return positive int on warning
 // return negative int on error
 //
-static int llama_encode_internal(
+static int llama_encode_impl(
          llama_context & lctx,
            llama_batch   inp_batch) {
 
@@ -18847,7 +11132,8 @@ static int llama_encode_internal(
     }
 
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+
     const llama_batch & batch = batch_allocr.batch;
     const uint32_t n_tokens = batch.n_tokens;
 
@@ -19016,7 +11302,7 @@ static int llama_encode_internal(
 }
 
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
-static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
 
     const auto & hparams = lctx.model.hparams;
@@ -19236,7 +11522,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
 }
 
-static void llama_kv_cache_update_internal(struct llama_context & lctx) {
+static void llama_kv_cache_update_impl(struct llama_context & lctx) {
     bool need_reserve = false;
 
     if (lctx.kv_self.has_shift) {
@@ -19272,7 +11558,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     // defragment the KV cache if needed
     if (lctx.kv_self.do_defrag) {
-        llama_kv_cache_defrag_internal(lctx);
+        llama_kv_cache_defrag_impl(lctx);
 
         need_reserve = true;
 
@@ -19297,1054 +11583,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 }
 
-//
-// quantization
-//
-
-struct quantize_state_internal {
-    const llama_model                 & model;
-    const llama_model_quantize_params * params;
-
-    int n_attention_wv    = 0;
-    int n_ffn_down        = 0;
-    int n_ffn_gate        = 0;
-    int n_ffn_up          = 0;
-    int i_attention_wv    = 0;
-    int i_ffn_down        = 0;
-    int i_ffn_gate        = 0;
-    int i_ffn_up          = 0;
-
-    int n_k_quantized     = 0;
-    int n_fallback        = 0;
-
-    bool has_imatrix      = false;
-
-    // used to figure out if a model shares tok_embd with the output weight
-    bool has_output       = false;
-
-    quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
-        : model(model)
-        , params(params)
-        {}
-};
-
-static void llama_tensor_dequantize_internal(
-    struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
-    const size_t nelements, const int nthread
-) {
-    if (output.size() < nelements) {
-        output.resize(nelements);
-    }
-    float * f32_output = (float *) output.data();
-
-    const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
-    if (ggml_is_quantized(tensor->type)) {
-        if (qtype->to_float == NULL) {
-            throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
-        }
-    } else if (tensor->type != GGML_TYPE_F16 &&
-               tensor->type != GGML_TYPE_BF16) {
-        throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
-    }
-
-    if (nthread < 2) {
-        if (tensor->type == GGML_TYPE_F16) {
-            ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
-        } else if (tensor->type == GGML_TYPE_BF16) {
-            ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
-        } else if (ggml_is_quantized(tensor->type)) {
-            qtype->to_float(tensor->data, f32_output, nelements);
-        } else {
-            GGML_ABORT("fatal error"); // unreachable
-        }
-        return;
-    }
-
-    size_t block_size;
-    if (tensor->type == GGML_TYPE_F16 ||
-        tensor->type == GGML_TYPE_BF16) {
-        block_size = 1;
-    } else {
-        block_size = (size_t)ggml_blck_size(tensor->type);
-    }
-
-    size_t block_size_bytes = ggml_type_size(tensor->type);
-
-    GGML_ASSERT(nelements % block_size == 0);
-    size_t nblocks = nelements / block_size;
-    size_t blocks_per_thread = nblocks / nthread;
-    size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
-
-    size_t in_buff_offs = 0;
-    size_t out_buff_offs = 0;
-
-    for (int tnum = 0; tnum < nthread; tnum++) {
-        size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
-        size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
-        size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
-
-        auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
-            if (typ == GGML_TYPE_F16) {
-                ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
-            } else if (typ == GGML_TYPE_BF16) {
-                ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
-            } else {
-                qtype->to_float(inbuf, outbuf, nels);
-            }
-        };
-        workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
-        in_buff_offs += thr_block_bytes;
-        out_buff_offs += thr_elems;
-    }
-    for (auto & w : workers) { w.join(); }
-    workers.clear();
-}
-
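A hedged sketch of the work split used above: each worker gets nblocks/nthread blocks and the remainder is appended to the last one (assumes nthread >= 1; the helper name is illustrative).

#include <cstddef>
#include <vector>

static std::vector<std::size_t> split_blocks(std::size_t nblocks, std::size_t nthread) {
    std::vector<std::size_t> per_thread(nthread, nblocks / nthread);
    per_thread.back() += nblocks % nthread; // spare blocks go to the last worker
    return per_thread;
}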
-static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
-    const std::string name = ggml_get_name(tensor);
-
-    // TODO: avoid hardcoded tensor names - use the TN_* constants
-    const llm_arch arch = qs.model.arch;
-    const auto       tn = LLM_TN(arch);
-
-    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
-        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
-    };
-    const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
-    auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
-        if (n_expert > 1) {
-            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
-            // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
-            // for getting the current layer as I initially thought, and we need to resort to parsing the
-            // tensor name.
-            if (sscanf(name, "blk.%d.", &i_layer) != 1) {
-                throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
-            }
-            if (i_layer < 0 || i_layer >= n_layer) {
-                throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
-            }
-        }
-        return std::make_pair(i_layer, n_layer);
-    };
-
-    // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
-    // with the quantization of the output tensor
-    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
-        if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
-            new_type = qs.params->output_tensor_type;
-        } else {
-            int nx = tensor->ne[0];
-            if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
-                new_type = GGML_TYPE_Q8_0;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
-                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S  || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M   ||
-                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
-                new_type = GGML_TYPE_Q5_K;
-            }
-            else if (new_type != GGML_TYPE_Q8_0) {
-                new_type = GGML_TYPE_Q6_K;
-            }
-        }
-    } else if (name == "token_embd.weight") {
-        if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
-            new_type = qs.params->token_embedding_type;
-        } else {
-            if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
-                ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
-                new_type = GGML_TYPE_Q2_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
-                new_type = GGML_TYPE_IQ3_S;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
-                new_type = GGML_TYPE_IQ3_S;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-        }
-    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
-               ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
-        if (name.find("attn_v.weight") != std::string::npos) {
-            if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
-            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
-            ++qs.i_attention_wv;
-        }
-        else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (name.find("ffn_down") != std::string::npos) {
-            if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
-            }
-            ++qs.i_ffn_down;
-        }
-        else if (name.find("attn_output.weight") != std::string::npos) {
-            if (qs.model.hparams.n_expert == 8) {
-                new_type = GGML_TYPE_Q5_K;
-            } else {
-                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
-            }
-        }
-    } else if (name.find("attn_v.weight") != std::string::npos) {
-        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
-            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
-            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
-        }
-        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
-            new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
-        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
-            new_type = GGML_TYPE_Q5_K;
-        }
-        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
-                use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
-        if (qs.model.type == MODEL_70B) {
-            // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
-            // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
-            // nearly negligible increase in model size by quantizing this tensor with more bits:
-            if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
-        }
-        if (qs.model.hparams.n_expert == 8) {
-            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
-            // TODO: explore better strategies
-            new_type = GGML_TYPE_Q8_0;
-        }
-        ++qs.i_attention_wv;
-    } else if (name.find("attn_k.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert == 8) {
-            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
-            // TODO: explore better strategies
-            new_type = GGML_TYPE_Q8_0;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
-            new_type = GGML_TYPE_IQ3_XXS;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
-            new_type = GGML_TYPE_IQ2_S;
-        }
-    } else if (name.find("attn_q.weight") != std::string::npos) {
-        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
-            new_type = GGML_TYPE_IQ3_XXS;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
-            new_type = GGML_TYPE_IQ2_S;
-        }
-    } else if (name.find("ffn_down") != std::string::npos) {
-        auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
-        int i_layer = info.first, n_layer = info.second;
-        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
-            if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
-            new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
-            new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
-                     : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
-                     : GGML_TYPE_Q3_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
-                    (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
-            new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
-            if (arch == LLM_ARCH_FALCON) {
-                new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
-                           use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
-            } else {
-                if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
-            }
-        }
-        else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
-            new_type = GGML_TYPE_Q5_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
-            new_type = GGML_TYPE_Q5_K;
-        }
-        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
-                && qs.has_imatrix && i_layer < n_layer/8) {
-            // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
-            // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
-            // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
-            new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
-        }
-        ++qs.i_ffn_down;
-    } else if (name.find("attn_output.weight") != std::string::npos) {
-        if (arch != LLM_ARCH_FALCON) {
-            if (qs.model.hparams.n_expert == 8) {
-                if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
-                    ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL  ||
-                    ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S  ||
-                    ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
-                    new_type = GGML_TYPE_Q5_K;
-                }
-            } else {
-                if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   ) new_type = GGML_TYPE_Q3_K;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  ) new_type = GGML_TYPE_Q4_K;
-            }
-        } else {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
-        }
-    }
-    else if (name.find("attn_qkv.weight") != std::string::npos) {
-        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-    }
-    else if (name.find("ffn_gate") != std::string::npos) {
-        auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
-        int i_layer = info.first, n_layer = info.second;
-        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
-            new_type = GGML_TYPE_IQ3_XXS;
-        }
-        ++qs.i_ffn_gate;
-    }
-    else if (name.find("ffn_up") != std::string::npos) {
-        auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
-        int i_layer = info.first, n_layer = info.second;
-        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
-            new_type = GGML_TYPE_IQ3_XXS;
-        }
-        ++qs.i_ffn_up;
-    }
-
-    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
-    //}
-    // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
-    //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
-    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
-    //}
-    // This can be used to reduce the size of the Q5_K_S model.
-    // The associated PPL increase is fully in line with the size reduction
-    //else {
-    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
-    //}
-    bool convert_incompatible_tensor = false;
-    if (new_type == GGML_TYPE_Q2_K    || new_type == GGML_TYPE_Q3_K    || new_type == GGML_TYPE_Q4_K   ||
-        new_type == GGML_TYPE_Q5_K    || new_type == GGML_TYPE_Q6_K    || new_type == GGML_TYPE_IQ4_XS ||
-        new_type == GGML_TYPE_IQ2_XS  || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S  ||
-        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S   || new_type == GGML_TYPE_IQ3_S  ||
-        new_type == GGML_TYPE_IQ1_M) {
-        int nx = tensor->ne[0];
-        int ny = tensor->ne[1];
-        if (nx % QK_K != 0) {
-            LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
-            convert_incompatible_tensor = true;
-        } else {
-            ++qs.n_k_quantized;
-        }
-    }
-    if (convert_incompatible_tensor) {
-        switch (new_type) {
-            case GGML_TYPE_TQ1_0:
-            case GGML_TYPE_TQ2_0:  new_type = GGML_TYPE_Q4_0; break;  // TODO: use a symmetric type instead
-            case GGML_TYPE_IQ2_XXS:
-            case GGML_TYPE_IQ2_XS:
-            case GGML_TYPE_IQ2_S:
-            case GGML_TYPE_IQ3_XXS:
-            case GGML_TYPE_IQ3_S:
-            case GGML_TYPE_IQ1_S:
-            case GGML_TYPE_IQ1_M:
-            case GGML_TYPE_Q2_K:
-            case GGML_TYPE_Q3_K:
-            case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
-            case GGML_TYPE_Q4_K:   new_type = GGML_TYPE_Q5_0;   break;
-            case GGML_TYPE_Q5_K:   new_type = GGML_TYPE_Q5_1;   break;
-            case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
-            default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
-        }
-        if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
-            new_type = GGML_TYPE_F16;
-        }
-        LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
-        ++qs.n_fallback;
-    }
-
-    return new_type;
-}
-
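A worked example of the use_more_bits() rule above for a hypothetical 32-layer model: the first and last eighth of the layers, plus every third layer in between, get the higher-bit type.

#include <cstdio>

int main() {
    const int n_layers = 32;
    for (int i = 0; i < n_layers; ++i) {
        const bool more_bits = i < n_layers/8 || i >= 7*n_layers/8 || (i - n_layers/8) % 3 == 2;
        if (more_bits) { std::printf("%d ", i); } // prints 0 1 2 3 6 9 12 15 18 21 24 27 28 29 30 31
    }
    std::printf("\n");
    return 0;
}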
-static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
-    if (nthread < 2) {
-        // single-thread
-        size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
-        if (!ggml_validate_row_data(new_type, new_data, new_size)) {
-            throw std::runtime_error("quantized data validation failed");
-        }
-        return new_size;
-    }
-
-    std::mutex mutex;
-    int64_t counter = 0;
-    size_t new_size = 0;
-    bool valid = true;
-    auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
-            nrows, n_per_row, imatrix]() {
-        const int64_t nrows_per_chunk = chunk_size / n_per_row;
-        size_t local_size = 0;
-        while (true) {
-            std::unique_lock<std::mutex> lock(mutex);
-            int64_t first_row = counter; counter += nrows_per_chunk;
-            if (first_row >= nrows) {
-                if (local_size > 0) {
-                    new_size += local_size;
-                }
-                break;
-            }
-            lock.unlock();
-            const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
-            size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
-            local_size += this_size;
-
-            // validate the quantized data
-            const size_t row_size  = ggml_row_size(new_type, n_per_row);
-            void * this_data = (char *) new_data + first_row * row_size;
-            if (!ggml_validate_row_data(new_type, this_data, this_size)) {
-                std::unique_lock<std::mutex> lock(mutex);
-                valid = false;
-                break;
-            }
-        }
-    };
-    for (int it = 0; it < nthread - 1; ++it) {
-        workers.emplace_back(compute);
-    }
-    compute();
-    for (auto & w : workers) { w.join(); }
-    workers.clear();
-    if (!valid) {
-        throw std::runtime_error("quantized data validation failed");
-    }
-    return new_size;
-}
-
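An illustrative reduction (assumed names, not the llama.cpp API) of the shared-counter scheduling used above: workers repeatedly claim the next chunk index under a mutex until the rows are exhausted.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// fn is called concurrently from several threads and must be thread-safe
static void for_each_chunk(int64_t nrows, int64_t rows_per_chunk, int nthread,
                           const std::function<void(int64_t first, int64_t count)> & fn) {
    std::mutex mutex;
    int64_t counter = 0;
    auto worker = [&]() {
        while (true) {
            int64_t first;
            {
                std::lock_guard<std::mutex> lock(mutex);
                first = counter;
                counter += rows_per_chunk;
            }
            if (first >= nrows) { break; }
            fn(first, std::min(rows_per_chunk, nrows - first));
        }
    };
    std::vector<std::thread> workers;
    for (int i = 0; i < nthread - 1; ++i) { workers.emplace_back(worker); }
    worker(); // the calling thread participates too
    for (auto & w : workers) { w.join(); }
}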
-static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
-    ggml_type default_type;
-    llama_ftype ftype = params->ftype;
-
-    switch (params->ftype) {
-        case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
-        case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
-        case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
-        case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
-        case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
-        case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
-        case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
-        case LLAMA_FTYPE_ALL_F32:     default_type = GGML_TYPE_F32;  break;
-
-        // K-quants
-        case LLAMA_FTYPE_MOSTLY_Q2_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q2_K:    default_type = GGML_TYPE_Q2_K;    break;
-        case LLAMA_FTYPE_MOSTLY_IQ3_XS:  default_type = GGML_TYPE_IQ3_S;   break;
-        case LLAMA_FTYPE_MOSTLY_Q3_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q3_K_M:
-        case LLAMA_FTYPE_MOSTLY_Q3_K_L:  default_type = GGML_TYPE_Q3_K;    break;
-        case LLAMA_FTYPE_MOSTLY_Q4_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q4_K_M:  default_type = GGML_TYPE_Q4_K;    break;
-        case LLAMA_FTYPE_MOSTLY_Q5_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q5_K_M:  default_type = GGML_TYPE_Q5_K;    break;
-        case LLAMA_FTYPE_MOSTLY_Q6_K:    default_type = GGML_TYPE_Q6_K;    break;
-        case LLAMA_FTYPE_MOSTLY_TQ1_0:   default_type = GGML_TYPE_TQ1_0;   break;
-        case LLAMA_FTYPE_MOSTLY_TQ2_0:   default_type = GGML_TYPE_TQ2_0;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_XS:  default_type = GGML_TYPE_IQ2_XS;  break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_S:   default_type = GGML_TYPE_IQ2_XS;  break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_M:   default_type = GGML_TYPE_IQ2_S;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
-        case LLAMA_FTYPE_MOSTLY_IQ1_S:   default_type = GGML_TYPE_IQ1_S;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ1_M:   default_type = GGML_TYPE_IQ1_M;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ4_NL:  default_type = GGML_TYPE_IQ4_NL;  break;
-        case LLAMA_FTYPE_MOSTLY_IQ4_XS:  default_type = GGML_TYPE_IQ4_XS;  break;
-        case LLAMA_FTYPE_MOSTLY_IQ3_S:   default_type = GGML_TYPE_IQ3_S;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ3_M:   default_type = GGML_TYPE_IQ3_S;   break;
-
-        default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
-    }
-
-    int nthread = params->nthread;
-
-    if (nthread <= 0) {
-        nthread = std::thread::hardware_concurrency();
-    }
-
-    // mmap consistently increases speed on Linux, and also increases speed on Windows with
-    // a hot cache. It may cause a slowdown on macOS, possibly related to free memory.
-#if defined(__linux__) || defined(_WIN32)
-    constexpr bool use_mmap = true;
-#else
-    constexpr bool use_mmap = false;
-#endif
-
-    llama_model_kv_override * kv_overrides = nullptr;
-    if (params->kv_overrides) {
-        auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
-        kv_overrides = v->data();
-    }
-    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
-    ml.init_mappings(false); // no prefetching
-
-    llama_model model;
-    llm_load_arch(ml, model);
-    llm_load_hparams(ml, model);
-    llm_load_stats(ml, model);
-
-    struct quantize_state_internal qs(model, params);
-
-    if (params->only_copy) {
-        ftype = model.ftype;
-    }
-    const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
-    if (params->imatrix) {
-        imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
-        if (imatrix_data) {
-            LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
-            qs.has_imatrix = true;
-            // check imatrix for nans or infs
-            for (const auto & kv : *imatrix_data) {
-                for (float f : kv.second) {
-                    if (!std::isfinite(f)) {
-                        throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
-                    }
-                }
-            }
-        }
-    }
-
-    const size_t align = GGUF_DEFAULT_ALIGNMENT;
-    gguf_context_ptr ctx_out { gguf_init_empty() };
-
-    // copy the KV pairs from the input file
-    gguf_set_kv     (ctx_out.get(), ml.meta.get());
-    gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
-    gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
-
-    // Remove split metadata
-    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
-    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
-    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
-
-    if (params->kv_overrides) {
-        const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
-        for (const auto & o : overrides) {
-            if (o.key[0] == 0) break;
-            if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
-                gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
-            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
-                gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64);
-            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
-                gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
-            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
-                gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
-            } else {
-                LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
-            }
-        }
-    }
-
-    // make a list of weights
-    std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
-    tensors.reserve(ml.weights_map.size());
-    for (const auto & it : ml.weights_map) {
-        tensors.push_back(&it.second);
-    }
-
-    // keep_split requires that the weights are sorted by split index
-    if (params->keep_split) {
-        std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
-            if (a->idx == b->idx) {
-                return a->offs < b->offs;
-            }
-            return a->idx < b->idx;
-        });
-    }
-
-    for (const auto * it : tensors) {
-        const struct ggml_tensor * tensor = it->tensor;
-
-        const std::string name = ggml_get_name(tensor);
-
-        // TODO: avoid hardcoded tensor names - use the TN_* constants
-        if (name.find("attn_v.weight")   != std::string::npos ||
-            name.find("attn_qkv.weight") != std::string::npos ||
-            name.find("attn_kv_b.weight")!= std::string::npos) {
-            ++qs.n_attention_wv;
-        } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
-            qs.has_output = true;
-        }
-    }
-
-    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
-
-    // sanity checks
-    {
-        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
-        // attention layers have a non-zero number of kv heads
-        int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
-        if (llama_model_has_encoder(&model)) {
-            n_attn_layer *= 3;
-        }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
-    }
-
-    size_t total_size_org = 0;
-    size_t total_size_new = 0;
-
-    std::vector<std::thread> workers;
-    workers.reserve(nthread);
-
-    int idx = 0;
-
-    std::vector<no_init<uint8_t>> read_data;
-    std::vector<no_init<uint8_t>> work;
-    std::vector<no_init<float>> f32_conv_buf;
-
-    uint16_t n_split = 1;
-
-    // Assume split index is continuous
-    if (params->keep_split) {
-        for (const auto * it : tensors) {
-            n_split = std::max(uint16_t(it->idx + 1), n_split);
-        }
-    }
-    std::vector<gguf_context_ptr> ctx_outs(n_split);
-    ctx_outs[0] = std::move(ctx_out);
-
-    // populate the original tensors so we get an initial meta data
-    for (const auto * it : tensors) {
-        uint16_t i_split = params->keep_split ? it->idx : 0;
-        struct ggml_tensor * tensor = it->tensor;
-        if (!ctx_outs[i_split]) {
-            ctx_outs[i_split].reset(gguf_init_empty());
-        }
-        gguf_add_tensor(ctx_outs[i_split].get(), tensor);
-    }
-
-    // Set split info if needed
-    if (n_split > 1) {
-        for (size_t i = 0; i < ctx_outs.size(); ++i) {
-            gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
-            gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
-            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
-        }
-    }
-
-    int cur_split = -1;
-    std::ofstream fout;
-    auto close_ofstream = [&]() {
-        // Write metadata and close file handler
-        if (fout.is_open()) {
-            fout.seekp(0);
-            std::vector<uint8_t> data(gguf_get_meta_size(ctx_outs[cur_split].get()));
-            gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
-            fout.write((const char *) data.data(), data.size());
-            fout.close();
-        }
-    };
-    auto new_ofstream = [&](int index) {
-        cur_split = index;
-        GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context");
-        std::string fname = fname_out;
-        if (params->keep_split) {
-            char split_path[PATH_MAX] = {0};
-            llama_split_path(split_path, sizeof(split_path), fname_out.c_str(), cur_split, n_split);
-            fname = std::string(split_path);
-        }
-
-        fout = std::ofstream(fname, std::ios::binary);
-        fout.exceptions(std::ofstream::failbit); // fail fast on write errors
-        const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get());
-        // placeholder for the meta data
-        ::zeros(fout, meta_size);
-    };
-
-    const auto tn = LLM_TN(model.arch);
-    new_ofstream(0);
-    for (const auto * it : tensors) {
-        const auto & weight = *it;
-        struct ggml_tensor * tensor = weight.tensor;
-        if (weight.idx != cur_split && params->keep_split) {
-            close_ofstream();
-            new_ofstream(weight.idx);
-        }
-
-        const std::string name = ggml_get_name(tensor);
-
-        if (!ml.use_mmap) {
-            if (read_data.size() < ggml_nbytes(tensor)) {
-                read_data.resize(ggml_nbytes(tensor));
-            }
-            tensor->data = read_data.data();
-        }
-        ml.load_data_for(tensor);
-
-        LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
-               ++idx, ml.n_tensors,
-               ggml_get_name(tensor),
-               llama_format_tensor_shape(tensor).c_str(),
-               ggml_type_name(tensor->type));
-
-        // This used to be a regex, but <regex> has an extreme cost to compile times.
-        bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
-
-        // quantize only 2D and 3D tensors (experts)
-        quantize &= (ggml_n_dims(tensor) >= 2);
-
-        // do not quantize norm tensors
-        quantize &= name.find("_norm.weight") == std::string::npos;
-
-        quantize &= params->quantize_output_tensor || name != "output.weight";
-        quantize &= !params->only_copy;
-
-        // do not quantize expert gating tensors
-        // NOTE: can't use LLM_TN here because the layer number is not known
-        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
-
-        // do not quantize positional embeddings and token types (BERT)
-        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD,    "weight");
-        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
-
-        // do not quantize Mamba's small yet 2D weights
-        // NOTE: can't use LLM_TN here because the layer number is not known
-        quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
-
-        // do not quantize RWKV's time_mix_first tensors
-        quantize &= name.find("time_mix_first.weight") == std::string::npos;
-        quantize &= name.find("time_mix_w1.weight") == std::string::npos;
-        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
-        quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
-        quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
-
-        // do not quantize relative position bias (T5)
-        quantize &= name.find("attn_rel_b.weight") == std::string::npos;
-
-        enum ggml_type new_type;
-        void * new_data;
-        size_t new_size;
-
-        if (quantize) {
-            new_type = default_type;
-
-            // get more optimal quantization type based on the tensor shape, layer, etc.
-            if (!params->pure && ggml_is_quantized(default_type)) {
-                new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
-            }
-            if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
-                new_type = params->token_embedding_type;
-            }
-            if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
-                new_type = params->output_tensor_type;
-            }
-
-            // If we've decided to quantize to the same type the tensor is already
-            // in then there's nothing to do.
-            quantize = tensor->type != new_type;
-        }
-
-        if (!quantize) {
-            new_type = tensor->type;
-            new_data = tensor->data;
-            new_size = ggml_nbytes(tensor);
-            LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
-        } else {
-            const int64_t nelements = ggml_nelements(tensor);
-
-            const float * imatrix = nullptr;
-            if (imatrix_data) {
-                auto it = imatrix_data->find(tensor->name);
-                if (it == imatrix_data->end()) {
-                    LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
-                } else {
-                    if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
-                        imatrix = it->second.data();
-                    } else {
-                        LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
-                                int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
-
-                        // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
-                        // this is a significant error and it may be a good idea to abort the process if this happens,
-                        // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
-                        // tok_embd should be ignored in this case, since it always causes this warning
-                        if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
-                            throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
-                                    int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
-                        }
-                    }
-                }
-            }
-            if ((new_type == GGML_TYPE_IQ2_XXS ||
-                 new_type == GGML_TYPE_IQ2_XS  ||
-                 new_type == GGML_TYPE_IQ2_S   ||
-                 new_type == GGML_TYPE_IQ1_S   ||
-                (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight"))  ||
-                (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
-                LLAMA_LOG_ERROR("\n\n============================================================\n");
-                LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
-                LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
-                LLAMA_LOG_ERROR("============================================================\n\n");
-                throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
-            }
-
-            float * f32_data;
-
-            if (tensor->type == GGML_TYPE_F32) {
-                f32_data = (float *) tensor->data;
-            } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
-                throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
-            } else {
-                llama_tensor_dequantize_internal(tensor, f32_conv_buf, workers, nelements, nthread);
-                f32_data = (float *) f32_conv_buf.data();
-            }
-
-            LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
-            fflush(stdout);
-
-            if (work.size() < (size_t)nelements * 4) {
-                work.resize(nelements * 4); // upper bound on size
-            }
-            new_data = work.data();
-
-            const int64_t n_per_row = tensor->ne[0];
-            const int64_t nrows = tensor->ne[1];
-
-            static const int64_t min_chunk_size = 32 * 512;
-            const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
-
-            const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
-            const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
-            const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
-
-            // quantize each expert separately since they have different importance matrices
-            new_size = 0;
-            for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
-                const float * f32_data_03 = f32_data + i03 * nelements_matrix;
-                void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
-                const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
-
-                new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
-            }
-            LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
-        }
-        total_size_org += ggml_nbytes(tensor);
-        total_size_new += new_size;
-
-        // update the gguf meta data as we go
-        gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
-        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data, new_size);
-
-        // write tensor data + padding
-        fout.write((const char *) new_data, new_size);
-        zeros(fout, GGML_PAD(new_size, align) - new_size);
-    }
-    close_ofstream();
-
-    LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
-    LLAMA_LOG_INFO("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
-
-    if (qs.n_fallback > 0) {
-        LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
-                __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
-    }
-}
-
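A condensed sketch of the name-based filtering applied per tensor above; should_quantize() is a hypothetical helper, and the skip list is abbreviated to a few of the suffixes checked in the loop.

#include <initializer_list>
#include <string>

static bool should_quantize(const std::string & name, int n_dims) {
    if (n_dims < 2) return false;                                                          // only 2D/3D tensors
    if (name.size() < 6 || name.compare(name.size() - 6, 6, "weight") != 0) return false;  // must end with "weight"
    for (const char * skip : {"_norm.weight", "ffn_gate_inp.weight", "ssm_conv1d.weight", "attn_rel_b.weight"}) {
        if (name.find(skip) != std::string::npos) return false;
    }
    return true;
}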
-static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
-    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
-
-    ggml_context * ctx_init;
-    struct gguf_init_params meta_gguf_params = {
-        /* .no_alloc = */ true,
-        /* .ctx      = */ &ctx_init,
-    };
-
-    gguf_context_ptr ctx_gguf { gguf_init_from_file(path_lora, meta_gguf_params) };
-    if (!ctx_gguf) {
-        throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora));
-    }
-
-    ggml_context_ptr ctx { ctx_init };
-
-    // check metadata
-    {
-        auto get_kv_str = [&](const std::string & key) -> std::string {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
-        };
-        auto get_kv_f32 = [&](const std::string & key) -> float {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
-        };
-        LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
-
-        auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE));
-        if (general_type != "adapter") {
-            throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
-        }
-
-        auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
-        auto general_arch = llm_arch_from_string(general_arch_str);
-        if (general_arch != model->arch) {
-            throw std::runtime_error("model arch and LoRA arch mismatch");
-        }
-
-        auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE));
-        if (adapter_type != "lora") {
-            throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
-        }
-
-        adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
-    }
-
-    int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
-
-    // contexts for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            // add a new context
-            struct ggml_init_params params = {
-                /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-            ggml_context * buft_ctx = ggml_init(params);
-            if (!buft_ctx) {
-                return nullptr;
-            }
-            ctx_map[buft] = buft_ctx;
-            adapter.ctxs.emplace_back(buft_ctx);
-            return buft_ctx;
-        };
-        return it->second;
-    };
-
-    // bundle lora_a and lora_b into pairs
-    std::map<std::string, llama_lora_weight> ab_map;
-    auto str_endswith = [](const std::string & str, const std::string & suffix) {
-        return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
-    };
-    for (ggml_tensor * cur = ggml_get_first_tensor(ctx.get()); cur; cur = ggml_get_next_tensor(ctx.get(), cur)) {
-        std::string name(cur->name);
-        if (str_endswith(name, ".lora_a")) {
-            replace_all(name, ".lora_a", "");
-            if (ab_map.find(name) == ab_map.end()) {
-                ab_map[name] = llama_lora_weight(cur, nullptr);
-            } else {
-                ab_map[name].a = cur;
-            }
-        } else if (str_endswith(name, ".lora_b")) {
-            replace_all(name, ".lora_b", "");
-            if (ab_map.find(name) == ab_map.end()) {
-                ab_map[name] = llama_lora_weight(nullptr, cur);
-            } else {
-                ab_map[name].b = cur;
-            }
-        } else {
-            throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
-        }
-    }
-
-    // add tensors
-    for (auto & it : ab_map) {
-        const std::string & name = it.first;
-        llama_lora_weight & w = it.second;
-
-        if (!w.a || !w.b) {
-            throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
-        }
-
-        // device buft and device ctx
-        auto * model_tensor = llama_get_model_tensor(model, name.c_str());
-        if (!model_tensor) {
-            throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
-        }
-        struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
-        // validate tensor shape
-        if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
-            throw std::runtime_error("tensor '" + name + "' has incorrect shape");
-        }
-        if (w.a->ne[1] != w.b->ne[0]) {
-            throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
-        }
-        // save tensor to adapter
-        struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
-        struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
-        ggml_set_name(tensor_a, w.a->name);
-        ggml_set_name(tensor_b, w.b->name);
-        adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
-    }
-
-    // allocate tensors / buffers and zero
-    {
-        adapter.ctxs.reserve(ctx_map.size());
-        adapter.bufs.reserve(ctx_map.size());
-        for (auto & it : ctx_map) {
-            ggml_backend_buffer_type_t buft = it.first;
-            ggml_context * ctx_dev = it.second;
-            ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft) };
-            if (!buf) {
-                throw std::runtime_error("failed to allocate buffer for lora adapter\n");
-            }
-            LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0);
-            adapter.bufs.emplace_back(std::move(buf));
-        }
-    }
-
-    // set tensor data
-    {
-        llama_file gguf_file(path_lora, "rb");
-        std::vector<uint8_t> read_buf;
-        auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
-            size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
-            size_t size = ggml_nbytes(orig);
-            read_buf.resize(size);
-            gguf_file.seek(offs, SEEK_SET);
-            gguf_file.read_raw(read_buf.data(), size);
-            ggml_backend_tensor_set(dev, read_buf.data(), 0, size);
-        };
-        for (auto & it : adapter.ab_map) {
-            auto orig = ab_map[it.first];
-            auto dev  = it.second;
-            set_tensor(orig.a, dev.a);
-            set_tensor(orig.b, dev.b);
-        }
-    }
-
-    LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
-}
-
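A hedged sketch of the A/B pairing step above: strip the ".lora_a" / ".lora_b" suffix and group both halves under the base tensor name. The lora_pair type and void pointers are illustrative stand-ins for ggml_tensor.

#include <map>
#include <string>

struct lora_pair { const void * a = nullptr; const void * b = nullptr; };

static void add_lora_tensor(std::map<std::string, lora_pair> & ab_map,
                            const std::string & name, const void * tensor) {
    auto ends_with = [](const std::string & s, const std::string & suf) {
        return s.size() >= suf.size() && s.compare(s.size() - suf.size(), suf.size(), suf) == 0;
    };
    if (ends_with(name, ".lora_a")) {
        ab_map[name.substr(0, name.size() - 7)].a = tensor; // ".lora_a" is 7 chars
    } else if (ends_with(name, ".lora_b")) {
        ab_map[name.substr(0, name.size() - 7)].b = tensor;
    } // any other suffix is rejected by the loader above
}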
 int32_t llama_lora_adapter_set(
             struct llama_context * ctx,
             struct llama_lora_adapter * adapter,
             float scale) {
-    if (ctx->cparams.flash_attn) {
-        LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
-        return -1;
-    }
     ctx->lora_adapters[adapter] = scale;
     return 0;
 }
@@ -20357,6 +11599,7 @@ int32_t llama_lora_adapter_remove(
         ctx->lora_adapters.erase(pos);
         return 0;
     }
+
     return -1;
 }
 
@@ -20364,37 +11607,20 @@ void llama_lora_adapter_clear(struct llama_context * ctx) {
     ctx->lora_adapters.clear();
 }
 
-void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
-    delete adapter;
+// TODO: tmp
+int32_t llama_control_vector_apply(
+        struct llama_context * lctx,
+                 const float * data,
+                      size_t   len,
+                     int32_t   n_embd,
+                     int32_t   il_start,
+                     int32_t   il_end) {
+    return llama_control_vector_apply(lctx->cvec, lctx->model, data, len, n_embd, il_start, il_end);
 }
 
 //
 // interface implementation
 //
-struct llama_model_params llama_model_default_params() {
-    struct llama_model_params result = {
-        /*.devices                     =*/ nullptr,
-        /*.n_gpu_layers                =*/ 0,
-        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
-        /*.main_gpu                    =*/ 0,
-        /*.tensor_split                =*/ nullptr,
-        /*.rpc_servers                 =*/ nullptr,
-        /*.progress_callback           =*/ nullptr,
-        /*.progress_callback_user_data =*/ nullptr,
-        /*.kv_overrides                =*/ nullptr,
-        /*.vocab_only                  =*/ false,
-        /*.use_mmap                    =*/ true,
-        /*.use_mlock                   =*/ false,
-        /*.check_tensors               =*/ false,
-    };
-
-#ifdef GGML_USE_METAL
-    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
-    result.n_gpu_layers = 999;
-#endif
-
-    return result;
-}
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
@@ -20439,24 +11665,6 @@ struct llama_sampler_chain_params llama_sampler_chain_default_params() {
     return result;
 }
 
-struct llama_model_quantize_params llama_model_quantize_default_params() {
-    struct llama_model_quantize_params result = {
-        /*.nthread                     =*/ 0,
-        /*.ftype                       =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
-        /*.output_tensor_type          =*/ GGML_TYPE_COUNT,
-        /*.token_embedding_type        =*/ GGML_TYPE_COUNT,
-        /*.allow_requantize            =*/ false,
-        /*.quantize_output_tensor      =*/ true,
-        /*.only_copy                   =*/ false,
-        /*.pure                        =*/ false,
-        /*.keep_split                  =*/ false,
-        /*.imatrix                     =*/ nullptr,
-        /*.kv_overrides                =*/ nullptr,
-    };
-
-    return result;
-}
-
 size_t llama_max_devices(void) {
     return 16;
 }
@@ -20499,19 +11707,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
     }
 }
 
-void llama_attach_threadpool(
-             struct llama_context * ctx,
-        ggml_threadpool_t   threadpool,
-        ggml_threadpool_t   threadpool_batch) {
-    ctx->threadpool       = threadpool;
-    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
-}
-
-void llama_detach_threadpool(struct llama_context * ctx) {
-    ctx->threadpool       = nullptr;
-    ctx->threadpool_batch = nullptr;
-}
-
 void llama_backend_free(void) {
     ggml_quantize_free();
 }
@@ -20522,7 +11717,13 @@ int64_t llama_time_us(void) {
 
 struct llama_model * llama_load_model_from_file(
         const char * path_model,
-        struct llama_model_params   params) {
+        struct llama_model_params params) {
+    return llama_model_load_from_file(path_model, params);
+}
+
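+// primary entry point; llama_load_model_from_file() above now forwards here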
+struct llama_model * llama_model_load_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
     ggml_time_init();
 
     llama_model * model = new llama_model;
@@ -20561,7 +11762,7 @@ struct llama_model * llama_load_model_from_file(
         ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
         if (!rpc_reg) {
             LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
-            llama_free_model(model);
+            llama_model_free(model);
             return nullptr;
         }
 
@@ -20569,7 +11770,7 @@ struct llama_model * llama_load_model_from_file(
         ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
         if (!ggml_backend_rpc_add_device_fn) {
             LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
-            llama_free_model(model);
+            llama_model_free(model);
             return nullptr;
         }
 
@@ -20579,7 +11780,7 @@ struct llama_model * llama_load_model_from_file(
                 model->devices.push_back(dev);
             } else {
                 LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
-                llama_free_model(model);
+                llama_model_free(model);
                 return nullptr;
             }
         }
@@ -20611,7 +11812,7 @@ struct llama_model * llama_load_model_from_file(
     if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
         if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) {
             LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size());
-            llama_free_model(model);
+            llama_model_free(model);
             return nullptr;
         }
         ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
@@ -20633,17 +11834,14 @@ struct llama_model * llama_load_model_from_file(
         } else if (status == -2) {
             LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
         }
-        llama_free_model(model);
+
+        llama_model_free(model);
         return nullptr;
     }
 
     return model;
 }
 
-void llama_free_model(struct llama_model * model) {
-    delete model;
-}
-
 struct llama_context * llama_new_context_with_model(
                  struct llama_model * model,
         struct llama_context_params   params) {
@@ -20844,7 +12042,7 @@ struct llama_context * llama_new_context_with_model(
 
         llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
+        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, ctx->cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
@@ -20995,442 +12193,26 @@ struct llama_context * llama_new_context_with_model(
     return ctx;
 }
 
-void llama_free(struct llama_context * ctx) {
-    delete ctx;
-}
+//
+// kv cache
+//
 
-uint32_t llama_n_ctx(const struct llama_context * ctx) {
-    return ctx->cparams.n_ctx;
-}
-
-uint32_t llama_n_batch(const struct llama_context * ctx) {
-    return ctx->cparams.n_batch;
-}
-
-uint32_t llama_n_ubatch(const struct llama_context * ctx) {
-    return ctx->cparams.n_ubatch;
-}
-
-uint32_t llama_n_seq_max(const struct llama_context * ctx) {
-    return ctx->kv_self.size;
-}
-
-enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
-    return model->vocab.type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
-
-int32_t llama_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
-
-int32_t llama_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
-
-int32_t llama_n_head(const struct llama_model * model) {
-    return model->hparams.n_head();
-}
-
-const struct llama_model * llama_get_model(const struct llama_context * ctx) {
-    return &ctx->model;
-}
-
-enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
-    return ctx->cparams.pooling_type;
-}
-
-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
-    switch (model->arch) {
-        // these models do not use RoPE
-        case LLM_ARCH_GPT2:
-        case LLM_ARCH_GPTJ:
-        case LLM_ARCH_MPT:
-        case LLM_ARCH_REFACT:
-        case LLM_ARCH_BLOOM:
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_JINA_BERT_V2:
-        case LLM_ARCH_T5:
-        case LLM_ARCH_T5ENCODER:
-        case LLM_ARCH_JAIS:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_WAVTOKENIZER_DEC:
-            return LLAMA_ROPE_TYPE_NONE;
-
-        // use what we call a normal RoPE, operating on pairs of consecutive head values
-        case LLM_ARCH_LLAMA:
-        case LLM_ARCH_DECI:
-        case LLM_ARCH_BAICHUAN:
-        case LLM_ARCH_STARCODER:
-        case LLM_ARCH_PLAMO:
-        case LLM_ARCH_ORION:
-        case LLM_ARCH_INTERNLM2:
-        case LLM_ARCH_MINICPM:
-        case LLM_ARCH_XVERSE:
-        case LLM_ARCH_COMMAND_R:
-        case LLM_ARCH_OLMO:
-        case LLM_ARCH_ARCTIC:
-        case LLM_ARCH_DEEPSEEK:
-        case LLM_ARCH_DEEPSEEK2:
-        case LLM_ARCH_CHATGLM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
-        case LLM_ARCH_CHAMELEON:
-            return LLAMA_ROPE_TYPE_NORM;
-
-        // the pairs of head values are offset by n_rot/2
-        case LLM_ARCH_FALCON:
-        case LLM_ARCH_GROK:
-        case LLM_ARCH_DBRX:
-        case LLM_ARCH_BERT:
-        case LLM_ARCH_NOMIC_BERT:
-        case LLM_ARCH_STABLELM:
-        case LLM_ARCH_BITNET:
-        case LLM_ARCH_QWEN:
-        case LLM_ARCH_QWEN2:
-        case LLM_ARCH_QWEN2MOE:
-        case LLM_ARCH_OLMO2:
-        case LLM_ARCH_OLMOE:
-        case LLM_ARCH_PHI2:
-        case LLM_ARCH_PHI3:
-        case LLM_ARCH_GEMMA:
-        case LLM_ARCH_GEMMA2:
-        case LLM_ARCH_STARCODER2:
-        case LLM_ARCH_OPENELM:
-        case LLM_ARCH_GPTNEOX:
-        case LLM_ARCH_CODESHELL:
-        case LLM_ARCH_NEMOTRON:
-        case LLM_ARCH_EXAONE:
-        case LLM_ARCH_MINICPM3:
-            return LLAMA_ROPE_TYPE_NEOX;
-
-        case LLM_ARCH_QWEN2VL:
-            return LLAMA_ROPE_TYPE_MROPE;
-
-        // all model arches should be listed explicitly here
-        case LLM_ARCH_UNKNOWN:
-            GGML_ABORT("unknown architecture");
-    }
-
-    return LLAMA_ROPE_TYPE_NONE;
-}
-
-float llama_rope_freq_scale_train(const struct llama_model * model) {
-    return model->hparams.rope_freq_scale_train;
-}
-
-int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
-    const auto & it = model->gguf_kv.find(key);
-    if (it == model->gguf_kv.end()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
-
-int32_t llama_model_meta_count(const struct llama_model * model) {
-    return (int)model->gguf_kv.size();
-}
-
-int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->first.c_str());
-}
-
-int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
-
-int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "%s %s %s",
-            llama_model_arch_name(model->arch),
-            llama_model_type_name(model->type),
-            llama_model_ftype_name(model->ftype).c_str());
-}
-
-uint64_t llama_model_size(const struct llama_model * model) {
-    return model->n_bytes;
-}
-
-uint64_t llama_model_n_params(const struct llama_model * model) {
-    return model->n_elements;
-}
-
-bool llama_model_has_encoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5:        return true;
-        case LLM_ARCH_T5ENCODER: return true;
-        default:                 return false;
-    }
-}
-
-bool llama_model_has_decoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5ENCODER: return false;
-        default:                 return true;
-    }
-}
-
-llama_token llama_model_decoder_start_token(const struct llama_model * model) {
-    return model->hparams.dec_start_token_id;
-}
-
-bool llama_model_is_recurrent(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_MAMBA:  return true;
-        case LLM_ARCH_RWKV6:  return true;
-        default:              return false;
-    }
-}
-
-uint32_t llama_model_quantize(
-        const char * fname_inp,
-        const char * fname_out,
-        const llama_model_quantize_params * params) {
-    try {
-        llama_model_quantize_internal(fname_inp, fname_out, params);
-        return 0;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
-        return 1;
-    }
-}
-
-struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
-    try {
-        struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
-        llama_lora_adapter_init_internal(model, path_lora, *adapter);
-        return adapter;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
-        return nullptr;
-    }
-}
-
-static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
-    GGML_ASSERT(cvec.tensors.empty());
-    GGML_ASSERT(cvec.ctxs.empty());
-    GGML_ASSERT(cvec.bufs.empty());
-
-    // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            struct ggml_init_params params = {
-                /*.mem_size   =*/ model.hparams.n_layer*ggml_tensor_overhead(),
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-            ggml_context * ctx = ggml_init(params);
-            if (!ctx) {
-                return nullptr;
-            }
-            ctx_map[buft] = ctx;
-            cvec.ctxs.emplace_back(ctx);
-            return ctx;
-        }
-        return it->second;
-    };
-
-    // make tensors
-    cvec.tensors.reserve(model.hparams.n_layer);
-    cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
-    for (size_t il = 1; il < model.hparams.n_layer; il++) {
-        ggml_backend_buffer_type_t buft = select_buft(*model.dev_layer.at(il).buft_list,
-            [&](ggml_context * ctx) {
-                ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
-                ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
-                return ggml_add(ctx, cur, layer_dir);
-            });
-        ggml_context * ctx = ctx_for_buft(buft);
-        if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
-            return false;
-        }
-        ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
-        cvec.tensors.push_back(tensor);
-    }
-
-    // allocate tensors / buffers and zero
-    cvec.bufs.reserve(ctx_map.size());
-    for (auto it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx = it.second;
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-        if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
-            return false;
-        }
-        ggml_backend_buffer_clear(buf, 0);
-        cvec.bufs.emplace_back(buf);
-    }
-
-    return true;
-}
-
-int32_t llama_control_vector_apply(struct llama_context * lctx, const float * data, size_t len, int32_t n_embd, int32_t il_start, int32_t il_end) {
-    const llama_model & model = lctx->model;
-    llama_control_vector & cvec = lctx->cvec;
-
-    if (data == nullptr) {
-        // disable the current control vector (but leave allocated for later)
-        cvec.layer_start = -1;
-        cvec.layer_end   = -1;
-        return 0;
-    }
-
-    if (n_embd != (int) model.hparams.n_embd) {
-        LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
-        return 1;
-    }
-
-    if (cvec.tensors.empty()) {
-        if (!llama_control_vector_init(cvec, model)) {
-            return 1;
-        }
-    }
-
-    cvec.layer_start = il_start;
-    cvec.layer_end   = il_end;
-
-    for (size_t il = 1; il < model.hparams.n_layer; il++) {
-        assert(cvec.tensors[il] != nullptr);
-
-        const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
-        if (off + n_embd <= len) {
-            ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
-        }
-    }
-
-    return 0;
-}
+// TODO: tmp bridges below until `struct llama_kv_cache` is exposed through the public API
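+// each wrapper below simply forwards to an overload that takes ctx->kv_self directly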
 
 struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
-    struct llama_kv_cache_view result = {
-        /*.n_cells            = */ 0,
-        /*.n_seq_max          = */ n_seq_max,
-        /*.token_count        = */ 0,
-        /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
-        /*.max_contiguous     = */ 0,
-        /*.max_contiguous_idx = */ -1,
-        /*.cells              = */ nullptr,
-        /*.cells_sequences    = */ nullptr,
-    };
-    return result;
-}
-
-void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
-    if (view->cells != nullptr) {
-        free(view->cells);
-        view->cells = nullptr;
-    }
-    if (view->cells_sequences != nullptr) {
-        free(view->cells_sequences);
-        view->cells_sequences = nullptr;
-    }
+    return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
 }
 
 void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
-    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
-        view->n_cells = int32_t(ctx->kv_self.size);
-        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
-        view->cells = (struct llama_kv_cache_view_cell *)p;
-        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
-        view->cells_sequences = (llama_seq_id *)p;
-    }
-
-    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
-    llama_kv_cache_view_cell * c_curr = view->cells;
-    llama_seq_id * cs_curr = view->cells_sequences;
-    int32_t used_cells = 0;
-    int32_t token_count = 0;
-    int32_t curr_contig_idx = -1;
-    uint32_t max_contig = 0;
-    int32_t max_contig_idx = -1;
-
-    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
-        const size_t curr_size = kv_cells[i].seq_id.size();
-        token_count += curr_size;
-        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
-
-        if (curr_size > 0) {
-            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
-                max_contig = i - curr_contig_idx;
-                max_contig_idx = curr_contig_idx;
-            }
-            curr_contig_idx = -1;
-        } else if (curr_contig_idx < 0) {
-            curr_contig_idx = i;
-        }
-
-        int seq_idx = 0;
-        for (const llama_seq_id it : kv_cells[i].seq_id) {
-            if (seq_idx >= view->n_seq_max) {
-                break;
-            }
-            cs_curr[seq_idx] = it;
-            seq_idx++;
-        }
-        if (seq_idx != 0) {
-            used_cells++;
-        }
-        for (; seq_idx < view->n_seq_max; seq_idx++) {
-            cs_curr[seq_idx] = -1;
-        }
-    }
-    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
-        max_contig_idx = curr_contig_idx;
-        max_contig = kv_cells.size() - curr_contig_idx;
-    }
-    view->max_contiguous = max_contig;
-    view->max_contiguous_idx = max_contig_idx;
-    view->token_count = token_count;
-    view->used_cells = used_cells;
-    if (uint32_t(used_cells) != ctx->kv_self.used) {
-        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-            __func__, ctx->kv_self.used, used_cells);
-    }
+    llama_kv_cache_view_update(view, ctx->kv_self);
 }
 
 int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    int result = 0;
-
-    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
-        result += ctx->kv_self.cells[i].seq_id.size();
-    }
-
-    return result;
+    return llama_get_kv_cache_token_count(ctx->kv_self);
 }
 
 int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
-    return ctx->kv_self.used;
+    return llama_get_kv_cache_used_cells(ctx->kv_self);
 }
 
 void llama_kv_cache_clear(struct llama_context * ctx) {
@@ -21477,1077 +12259,19 @@ void llama_kv_cache_defrag(struct llama_context * ctx) {
 }
 
 void llama_kv_cache_update(struct llama_context * ctx) {
-    llama_kv_cache_update_internal(*ctx);
+    llama_kv_cache_update_impl(*ctx);
 }
 
 bool llama_kv_cache_can_shift(struct llama_context * ctx) {
-    return !ctx->kv_self.recurrent && ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+    return llama_kv_cache_can_shift(ctx->kv_self);
 }
 
-// deprecated
-size_t llama_get_state_size(struct llama_context * ctx) {
-    return llama_state_get_size(ctx);
-}
-
-// deprecated
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    return llama_state_get_data(ctx, dst, -1);
-}
-
-// deprecated
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    return llama_state_set_data(ctx, src, -1);
-}
-
-// deprecated
-bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-}
-
-// deprecated
-bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
-}
-
-// TODO: replace all non-fatal assertions with returned errors or exceptions
-struct llama_data_write {
-    virtual void write(const void * src, size_t size) = 0;
-    virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
-    virtual size_t get_size_written() = 0;
-    virtual ~llama_data_write() = default;
-
-    void write_string(const std::string & str) {
-        uint32_t str_size = str.size();
-
-        write(&str_size,  sizeof(str_size));
-        write(str.data(), str_size);
-    }
-
-    void write_model_info(const struct llama_context * ctx) {
-        std::string arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
-        write_string(arch_str);
-        // TODO: add more model-specific info which should prevent loading the session file if not identical
-    }
-
-    //void write_rng(const std::mt19937 & rng) {
-    //    std::ostringstream rng_ss;
-    //    rng_ss << rng;
-
-    //    const std::string & rng_str = rng_ss.str();
-
-    //    write_string(rng_str);
-    //}
-
-    void write_output_ids(struct llama_context * ctx) {
-        llama_output_reorder(ctx);
-
-        const uint32_t n_outputs = ctx->n_outputs;
-
-        std::vector<int32_t> output_pos;
-
-        const size_t    n_batch = ctx->cparams.n_batch;
-        const auto & output_ids = ctx->output_ids;
-
-        GGML_ASSERT(n_outputs <= ctx->output_size);
-
-        output_pos.resize(n_outputs);
-
-        // build a more compact representation of the output ids
-        for (size_t i = 0; i < n_batch; ++i) {
-            // map an output id to a position in the batch
-            int32_t pos = output_ids[i];
-            if (pos >= 0) {
-                GGML_ASSERT((uint32_t) pos < n_outputs);
-                output_pos[pos] = i;
-            }
-        }
-
-        write(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs) {
-            write(output_pos.data(), n_outputs * sizeof(int32_t));
-        }
-    }
-
-    void write_logits(const struct llama_context * ctx) {
-        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
-
-        write(&logits_size, sizeof(logits_size));
-
-        if (logits_size) {
-            write(ctx->logits, logits_size * sizeof(float));
-        }
-    }
-
-    void write_embeddings(const struct llama_context * ctx) {
-        const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
-
-        write(&embeddings_size, sizeof(embeddings_size));
-
-        if (embeddings_size) {
-            write(ctx->embd, embeddings_size * sizeof(float));
-        }
-    }
-
-    void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
-
-        for (const auto & range : cell_ranges) {
-            for (uint32_t i = range.first; i < range.second; ++i) {
-                const auto & cell = kv_self.cells[i];
-                const llama_pos pos      = cell.pos;
-                const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
-
-                write(&pos,      sizeof(pos));
-                write(&n_seq_id, sizeof(n_seq_id));
-
-                if (n_seq_id) {
-                    for (auto seq_id : cell.seq_id) {
-                        write(&seq_id, sizeof(seq_id));
-                    }
-                }
-            }
-        }
-    }
-
-    void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
-        const struct llama_kv_cache & kv_self = ctx->kv_self;
-        const struct llama_hparams & hparams = ctx->model.hparams;
-
-        const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
-        const uint32_t n_layer = hparams.n_layer;
-
-        write(&v_trans, sizeof(v_trans));
-        write(&n_layer, sizeof(n_layer));
-
-        std::vector<uint8_t> tmp_buf;
-
-        // Iterate and write all the keys first, each row is a cell
-        // Get whole range at a time
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-            // Write key type
-            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-            write(&k_type_i, sizeof(k_type_i));
-
-            // Write row size of key
-            const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-            write(&k_size_row, sizeof(k_size_row));
-
-            // Read each range of cells of k_size length each into tmp_buf and write out
-            for (const auto & range : cell_ranges) {
-                const size_t range_size = range.second - range.first;
-                const size_t buf_size = range_size * k_size_row;
-                write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
-            }
-        }
-
-        if (!kv_self.v_trans) {
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-                // Write value type
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                write(&v_type_i, sizeof(v_type_i));
-
-                // Write row size of value
-                const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-                write(&v_size_row, sizeof(v_size_row));
-
-                // Read each range of cells of v_size length each into tmp_buf and write out
-                for (const auto & range : cell_ranges) {
-                    const size_t range_size = range.second - range.first;
-                    const size_t buf_size = range_size * v_size_row;
-                    write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
-                }
-            }
-        } else {
-            // When v is transposed, we also need the element size and get the element ranges from each row
-            const uint32_t kv_size = kv_self.size;
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-                // Write value type
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                write(&v_type_i, sizeof(v_type_i));
-
-                // Write element size
-                const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-                write(&v_size_el, sizeof(v_size_el));
-
-                // Write GQA embedding size
-                write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
-
-                // For each row, we get the element values of each cell
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    // Read each range of cells of v_size_el length each into tmp_buf and write out
-                    for (const auto & range : cell_ranges) {
-                        const size_t range_size = range.second - range.first;
-                        const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                        const size_t buf_size = range_size * v_size_el;
-                        write_tensor_data(kv_self.v_l[il], src_offset, buf_size);
-                    }
-                }
-            }
-        }
-    }
-
-    void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
-        const struct llama_kv_cache & kv_self = ctx->kv_self;
-        std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
-        uint32_t cell_count = 0;
-
-        // Count the number of cells with the specified seq_id
-        // Find all the ranges of cells with this seq id (or all, when -1)
-        uint32_t cell_range_begin = kv_self.size;
-        for (uint32_t i = 0; i < kv_self.size; ++i) {
-            const auto & cell = kv_self.cells[i];
-            if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
-                ++cell_count;
-                if (cell_range_begin == kv_self.size) {
-                    cell_range_begin = i;
-                }
-            } else {
-                if (cell_range_begin != kv_self.size) {
-                    cell_ranges.emplace_back(cell_range_begin, i);
-                    cell_range_begin = kv_self.size;
-                }
-            }
-        }
-        if (cell_range_begin != kv_self.size) {
-            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
-        }
-
-        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-        uint32_t cell_count_check = 0;
-        for (const auto & range : cell_ranges) {
-            cell_count_check += range.second - range.first;
-        }
-        GGML_ASSERT(cell_count == cell_count_check);
-
-        write(&cell_count, sizeof(cell_count));
-
-        write_kv_cache_meta(kv_self, cell_ranges, seq_id);
-        write_kv_cache_data(ctx, cell_ranges);
-    }
-};
-
-struct llama_data_read {
-    virtual const uint8_t * read(size_t size) = 0;
-    virtual void read_to(void * dst, size_t size) = 0;
-    virtual size_t get_size_read() = 0;
-    virtual ~llama_data_read() = default;
-
-    void read_string(std::string & str) {
-        uint32_t str_size;
-        read_to(&str_size, sizeof(str_size));
-
-        str.assign((const char *) read(str_size), str_size);
-    }
-
-    // validate model information
-    void read_model_info(const struct llama_context * ctx) {
-        std::string cur_arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
-        std::string arch_str;
-        read_string(arch_str);
-        if (cur_arch_str != arch_str) {
-            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
-        }
-        // TODO: add more info which needs to be identical but which is not verified otherwise
-    }
-
-    //void read_rng(std::mt19937 & rng) {
-    //    std::string rng_str;
-    //    read_string(rng_str);
-
-    //    std::istringstream rng_ss(rng_str);
-    //    rng_ss >> rng;
-
-    //    if (rng_ss.fail()) {
-    //        throw std::runtime_error("failed to load RNG state");
-    //    }
-    //}
-
-    void read_output_ids(struct llama_context * ctx) {
-        std::vector<int32_t> output_pos;
-
-        uint32_t n_outputs;
-        read_to(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
-            throw std::runtime_error("could not reserve outputs");
-        }
-
-        if (n_outputs) {
-            output_pos.resize(n_outputs);
-            read_to(output_pos.data(), n_outputs * sizeof(int32_t));
-
-            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
-                int32_t id = output_pos[i];
-                if ((uint32_t) id >= ctx->cparams.n_batch) {
-                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
-                }
-                ctx->output_ids[id] = i;
-            }
-
-            ctx->n_outputs = n_outputs;
-        }
-    }
-
-    void read_logits(struct llama_context * ctx) {
-        uint64_t logits_size;
-        read_to(&logits_size, sizeof(logits_size));
-
-        if (ctx->logits_size < logits_size) {
-            throw std::runtime_error("logits buffer too small");
-        }
-
-        if (logits_size) {
-            read_to(ctx->logits, logits_size * sizeof(float));
-        }
-    }
-
-    void read_embeddings(struct llama_context * ctx) {
-        uint64_t embeddings_size;
-        read_to(&embeddings_size, sizeof(embeddings_size));
-
-        if (ctx->embd_size < embeddings_size) {
-            throw std::runtime_error("embeddings buffer too small");
-        }
-
-        if (embeddings_size) {
-            read_to(ctx->embd, embeddings_size * sizeof(float));
-        }
-    }
-
-    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
-        struct llama_kv_cache & kv_self = ctx->kv_self;
-
-        if (dest_seq_id != -1) {
-            // single sequence
-
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-
-            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-            batch.n_tokens = cell_count;
-            batch.n_seq_tokens = cell_count;
-            batch.n_seqs = 1;
-
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                llama_pos pos;
-                uint32_t n_seq_id;
-
-                read_to(&pos, sizeof(pos));
-                read_to(&n_seq_id, sizeof(n_seq_id));
-
-                if (n_seq_id != 0) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
-                    return false;
-                }
-
-                batch.pos[i] = pos;
-            }
-            batch.n_seq_id[0] = 1;
-            batch.seq_id[0] = &dest_seq_id;
-            if (!llama_kv_cache_find_slot(kv_self, batch)) {
-                LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-                return false;
-            }
-
-            // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
-            // Assume that this is one contiguous block of cells
-            GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
-        } else {
-            // whole KV cache restore
-
-            if (cell_count > kv_self.size) {
-                LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
-                return false;
-            }
-
-            llama_kv_cache_clear(kv_self);
-
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                llama_kv_cell & cell = kv_self.cells[i];
-
-                llama_pos pos;
-                uint32_t  n_seq_id;
-
-                read_to(&pos,      sizeof(pos));
-                read_to(&n_seq_id, sizeof(n_seq_id));
-
-                cell.pos = pos;
-
-                for (uint32_t j = 0; j < n_seq_id; ++j) {
-                    llama_seq_id seq_id;
-                    read_to(&seq_id, sizeof(seq_id));
-
-                    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-                        return false;
-                    }
-
-                    cell.seq_id.insert(seq_id);
-
-                    if (kv_self.recurrent) {
-                        int32_t & tail = kv_self.cells[seq_id].tail;
-                        if (tail != -1) {
-                            LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
-                            return false;
-                        }
-                        tail = i;
-                    }
-                }
-            }
-
-            kv_self.head = 0;
-            kv_self.used = cell_count;
-        }
-
-        if (kv_self.recurrent) {
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                uint32_t cell_id = kv_self.head + i;
-                // make sure the recurrent states will keep their restored state
-                kv_self.cells[cell_id].src = cell_id;
-            }
-        }
-
-        return true;
-    }
-
-    bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
-        const struct llama_hparams & hparams = ctx->model.hparams;
-        struct llama_kv_cache & kv_self = ctx->kv_self;
-        uint32_t v_trans;
-        uint32_t n_layer;
-        read_to(&v_trans, sizeof(v_trans));
-        read_to(&n_layer, sizeof(n_layer));
-
-        if (n_layer != hparams.n_layer) {
-            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
-            return false;
-        }
-        if (cell_count > kv_self.size) {
-            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
-            return false;
-        }
-        if (kv_self.v_trans != (bool) v_trans) {
-            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
-            return false;
-        }
-
-        // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-            // Read type of key
-            int32_t k_type_i_ref;
-            read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-            if (k_type_i != k_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-                return false;
-            }
-
-            // Read row size of key
-            uint64_t k_size_row_ref;
-            read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-            const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-            if (k_size_row != k_size_row_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
-                return false;
-            }
-
-            if (cell_count) {
-                // Read and set the keys for the whole cell range
-                ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
-            }
-        }
-
-        if (!kv_self.v_trans) {
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-                // Read type of value
-                int32_t v_type_i_ref;
-                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                if (v_type_i != v_type_i_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                    return false;
-                }
-
-                // Read row size of value
-                uint64_t v_size_row_ref;
-                read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-                const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-                if (v_size_row != v_size_row_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
-                    return false;
-                }
-
-                if (cell_count) {
-                    // Read and set the values for the whole cell range
-                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
-                }
-            }
-        } else {
-            // For each layer, read the values for each cell (transposed)
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-                // Read type of value
-                int32_t v_type_i_ref;
-                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                if (v_type_i != v_type_i_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                    return false;
-                }
-
-                // Read element size of value
-                uint32_t v_size_el_ref;
-                read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-                const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-                if (v_size_el != v_size_el_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
-                    return false;
-                }
-
-                // Read GQA embedding size
-                uint32_t n_embd_v_gqa_ref;
-                read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-                if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
-                    return false;
-                }
-
-                if (cell_count) {
-                    // For each row in the transposed matrix, read the values for the whole cell range
-                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                        const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
-                        ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
-                    }
-                }
-            }
-        }
-        return true;
-    }
-
-    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
-        uint32_t cell_count;
-        read_to(&cell_count, sizeof(cell_count));
-
-        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
-
-        if (!res) {
-            if (seq_id == -1) {
-                llama_kv_cache_clear(ctx);
-            } else {
-                llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
-            }
-            throw std::runtime_error("failed to restore kv cache");
-        }
-    }
-};
-
-struct llama_data_write_dummy : llama_data_write {
-    size_t size_written = 0;
-
-    llama_data_write_dummy() {}
-
-    void write(const void * /* src */, size_t size) override {
-        size_written += size;
-    }
-
-    void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
-        size_written += size;
-    }
-
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
-
-struct llama_data_write_buffer : llama_data_write {
-    uint8_t * ptr;
-    size_t buf_size = 0;
-    size_t size_written = 0;
-
-    llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
-
-    void write(const void * src, size_t size) override {
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        memcpy(ptr, src, size);
-        ptr += size;
-        size_written += size;
-        buf_size -= size;
-    }
-
-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        ggml_backend_tensor_get(tensor, ptr, offset, size);
-        ptr += size;
-        size_written += size;
-        buf_size -= size;
-    }
-
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
-
-struct llama_data_read_buffer : llama_data_read {
-    const uint8_t * ptr;
-    size_t buf_size = 0;
-    size_t size_read = 0;
-
-    llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
-
-    const uint8_t * read(size_t size) override {
-        const uint8_t * base_ptr = ptr;
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        ptr += size;
-        size_read += size;
-        buf_size -= size;
-        return base_ptr;
-    }
-
-    void read_to(void * dst, size_t size) override {
-        memcpy(dst, read(size), size);
-    }
-
-    size_t get_size_read() override {
-        return size_read;
-    }
-};
-
-struct llama_data_write_file : llama_data_write {
-    llama_file * file;
-    size_t size_written = 0;
-    std::vector<uint8_t> temp_buffer;
-
-    llama_data_write_file(llama_file * f) : file(f) {}
-
-    void write(const void * src, size_t size) override {
-        file->write_raw(src, size);
-        size_written += size;
-    }
-
-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
-        temp_buffer.resize(size);
-        ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
-        write(temp_buffer.data(), temp_buffer.size());
-    }
-
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
-
-struct llama_data_read_file : llama_data_read {
-    llama_file * file;
-    size_t size_read = 0;
-    std::vector<uint8_t> temp_buffer;
-
-    llama_data_read_file(llama_file * f) : file(f) {}
-
-    void read_to(void * dst, size_t size) override {
-        file->read_raw(dst, size);
-        size_read += size;
-    }
-
-    const uint8_t * read(size_t size) override {
-        temp_buffer.resize(size);
-        read_to(temp_buffer.data(), size);
-        return temp_buffer.data();
-    }
-
-    size_t get_size_read() override {
-        return size_read;
-    }
-};
-
-/** copy state data into either a buffer or file depending on the passed in context
- *
- * file context:
- * llama_file file("/path", "wb");
- * llama_data_write_file data_ctx(&file);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
- * buffer context:
- * std::vector<uint8_t> buf(max_size, 0);
- * llama_data_write_buffer data_ctx(buf.data(), max_size);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
-*/
-static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
-    llama_synchronize(ctx);
-
-    data_ctx.write_model_info(ctx);
-
-    // copy outputs
-    data_ctx.write_output_ids(ctx);
-    data_ctx.write_logits(ctx);
-    data_ctx.write_embeddings(ctx);
-
-    data_ctx.write_kv_cache(ctx);
-
-    return data_ctx.get_size_written();
-}
-
-size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
-    llama_data_write_buffer data_ctx(dst, size);
-    try {
-        return llama_state_get_data_internal(ctx, data_ctx);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-// Returns the *actual* size of the state.
-// Intended to be used when saving the state to a buffer.
-size_t llama_state_get_size(struct llama_context * ctx) {
-    llama_data_write_dummy data_ctx;
-    try {
-        return llama_state_get_data_internal(ctx, data_ctx);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
-    llama_synchronize(ctx);
-
-    data_ctx.read_model_info(ctx);
-
-    // set outputs
-    data_ctx.read_output_ids(ctx);
-    data_ctx.read_logits(ctx);
-    data_ctx.read_embeddings(ctx);
-
-    data_ctx.read_kv_cache(ctx);
-
-    return data_ctx.get_size_read();
-}
-
-// Sets the state reading from the specified source address
-size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
-    llama_data_read_buffer data_ctx(src, size);
-    try {
-        return llama_state_set_data_internal(ctx, data_ctx);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(path_session, "rb");
-
-    // sanity checks
-    {
-        const uint32_t magic   = file.read_u32();
-        const uint32_t version = file.read_u32();
-
-        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
-            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-            return false;
-        }
-    }
-
-    // load the prompt
-    {
-        const uint32_t n_token_count = file.read_u32();
-
-        if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return false;
-        }
-
-        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-        *n_token_count_out = n_token_count;
-    }
-
-    // restore the context state
-    {
-        const size_t n_state_size_cur = file.size - file.tell();
-
-        llama_data_read_file data_ctx(&file);
-        const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
-
-        if (n_read != n_state_size_cur) {
-            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
-            return false;
-        }
-    }
-    return true;
-}
-
-bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    try {
-        return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
-        return false;
-    }
-}
-
-static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(path_session, "wb");
-
-    file.write_u32(LLAMA_SESSION_MAGIC);
-    file.write_u32(LLAMA_SESSION_VERSION);
-
-    // save the prompt
-    file.write_u32((uint32_t) n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    // save the context state using stream saving
-    llama_data_write_file data_ctx(&file);
-    llama_state_get_data_internal(ctx, data_ctx);
-
-    return true;
-}
-
-bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    try {
-        return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
-        return false;
-    }
-}
-
-static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
-    llama_synchronize(ctx);
-
-    data_ctx.write_kv_cache(ctx, seq_id);
-
-    return data_ctx.get_size_written();
-}
-
-size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_data_write_dummy data_ctx;
-    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-}
-
-size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
-    llama_data_write_buffer data_ctx(dst, size);
-    try {
-        return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
-    llama_synchronize(ctx);
-
-    data_ctx.read_kv_cache(ctx, dest_seq_id);
-
-    return data_ctx.get_size_read();
-}
-
-size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
-    llama_data_read_buffer data_ctx(src, size);
-    try {
-        return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(filepath, "wb");
-
-    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
-    file.write_u32(LLAMA_STATE_SEQ_VERSION);
-
-    // save the prompt
-    file.write_u32((uint32_t) n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    // save the context state using stream saving
-    llama_data_write_file data_ctx(&file);
-    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-
-    const size_t res = file.tell();
-    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
-    return res;
-}
-
-static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(filepath, "rb");
-
-    // version checks
-    {
-        const uint32_t magic   = file.read_u32();
-        const uint32_t version = file.read_u32();
-
-        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
-            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
-            return 0;
-        }
-    }
-
-    // load the prompt
-    {
-        const uint32_t n_token_count = file.read_u32();
-
-        if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return 0;
-        }
-
-        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-        *n_token_count_out = n_token_count;
-    }
-
-    // restore the context state
-    {
-        const size_t state_size = file.size - file.tell();
-        llama_data_read_file data_ctx(&file);
-        const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
-        if (!nread) {
-            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
-            return 0;
-        }
-        GGML_ASSERT(nread <= state_size);
-        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
-    }
-
-    return file.tell();
-}
-
-size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
-    try {
-        return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    try {
-        return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
-    ctx->cparams.n_threads       = n_threads;
-    ctx->cparams.n_threads_batch = n_threads_batch;
-}
-
-int32_t llama_n_threads(struct llama_context * ctx) {
-    return ctx->cparams.n_threads;
-}
-
-int32_t llama_n_threads_batch(struct llama_context * ctx) {
-    return ctx->cparams.n_threads_batch;
-}
-
-void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
-    ctx->abort_callback      = abort_callback;
-    ctx->abort_callback_data = abort_callback_data;
-
-    for (auto & backend : ctx->backends) {
-        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
-        auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
-        if (set_abort_callback_fn) {
-            set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
-        }
-    }
-}
-
-void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
-    ctx->cparams.embeddings = embeddings;
-}
-
-void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
-    ctx->cparams.causal_attn = causal_attn;
-}
-
-struct llama_batch llama_batch_get_one(
-             llama_token * tokens,
-                 int32_t   n_tokens) {
-    return {
-        /*n_tokens       =*/ n_tokens,
-        /*tokens         =*/ tokens,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ nullptr,
-        /*n_seq_id       =*/ nullptr,
-        /*seq_id         =*/ nullptr,
-        /*logits         =*/ nullptr,
-    };
-}
-
-struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
-    llama_batch batch = {
-        /*n_tokens       =*/ 0,
-        /*tokens         =*/ nullptr,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ nullptr,
-        /*n_seq_id       =*/ nullptr,
-        /*seq_id         =*/ nullptr,
-        /*logits         =*/ nullptr,
-    };
-
-    if (embd) {
-        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
-    } else {
-        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
-    }
-
-    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
-    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
-    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
-    for (int i = 0; i < n_tokens_alloc; ++i) {
-        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch.seq_id[n_tokens_alloc] = nullptr;
-
-    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
-
-    return batch;
-}
-
-void llama_batch_free(struct llama_batch batch) {
-    if (batch.token)    free(batch.token);
-    if (batch.embd)     free(batch.embd);
-    if (batch.pos)      free(batch.pos);
-    if (batch.n_seq_id) free(batch.n_seq_id);
-    if (batch.seq_id) {
-        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
-            free(batch.seq_id[i]);
-        }
-        free(batch.seq_id);
-    }
-    if (batch.logits)   free(batch.logits);
-}
+///
 
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_encode_internal(*ctx, batch);
+    const int ret = llama_encode_impl(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -22558,7 +12282,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_decode_internal(*ctx, batch);
+    const int ret = llama_decode_impl(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -22566,150 +12290,12 @@ int32_t llama_decode(
     return ret;
 }
 
-void llama_synchronize(struct llama_context * ctx) {
-    ggml_backend_sched_synchronize(ctx->sched.get());
-
-    // FIXME: if multiple single tokens are evaluated without a synchronization,
-    // the stats will be added to the prompt evaluation stats
-    // this should only happen when using batch size 1 to evaluate a batch
-
-    // add the evaluation to the stats
-    if (ctx->n_queued_tokens == 1) {
-        if (!ctx->cparams.no_perf) {
-            ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
-        }
-        ctx->n_eval++;
-    } else if (ctx->n_queued_tokens > 1) {
-        if (!ctx->cparams.no_perf) {
-            ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
-        }
-        ctx->n_p_eval += ctx->n_queued_tokens;
-    }
-
-    // get a more accurate load time, upon first eval
-    if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
-        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
-        ctx->has_evaluated_once = true;
-    }
-
-    ctx->n_queued_tokens = 0;
-    ctx->t_compute_start_us = 0;
-}
-
-float * llama_get_logits(struct llama_context * ctx) {
-    llama_synchronize(ctx);
-
-    // reorder logits for backward compatibility
-    // TODO: maybe deprecate this
-    llama_output_reorder(ctx);
-
-    return ctx->logits;
-}
-
-float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
-    int32_t j = -1;
-    llama_synchronize(ctx);
-
-    try {
-        if (ctx->logits == nullptr) {
-            throw std::runtime_error("no logits");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ABORT("fatal error");
-#else
-        return nullptr;
-#endif
-    }
-}
-
-float * llama_get_embeddings(struct llama_context * ctx) {
-    llama_synchronize(ctx);
-
-    // reorder embeddings for backward compatibility
-    // TODO: maybe deprecate this
-    llama_output_reorder(ctx);
-
-    return ctx->embd;
-}
-
-float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    int32_t j = -1;
-
-    llama_synchronize(ctx);
-
-    try {
-        if (ctx->embd == nullptr) {
-            throw std::runtime_error("no embeddings");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->embd + j*ctx->model.hparams.n_embd;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ABORT("fatal error");
-#else
-        return nullptr;
-#endif
-    }
-}
-
-float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_synchronize(ctx);
-
-    auto it = ctx->embd_seq.find(seq_id);
-    if (it == ctx->embd_seq.end()) {
-        return nullptr;
-    }
-
-    return it->second.data();
-}
-
 //
 // vocab
 //
 
+// TODO: tmp bridges below until `struct llama_vocab` is exposed through the public API
+
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
     return llama_token_get_text_impl(model->vocab, token);
 }
@@ -22842,478 +12428,6 @@ int32_t llama_detokenize(
 // chat templates
 //
 
-static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
-    if (LLM_CHAT_TEMPLATES.find(tmpl) != LLM_CHAT_TEMPLATES.end()) {
-        return LLM_CHAT_TEMPLATES.at(tmpl);
-    }
-    auto tmpl_contains = [&tmpl](const char * haystack) -> bool {
-        return tmpl.find(haystack) != std::string::npos;
-    };
-    if (tmpl_contains("<|im_start|>")) {
-        return LLM_CHAT_TEMPLATE_CHATML;
-    } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
-        if (tmpl_contains("[SYSTEM_PROMPT]")) {
-            return LLM_CHAT_TEMPLATE_MISTRAL_V7;
-        } else if (
-            // catches official 'v1' template
-            tmpl_contains("' [INST] ' + system_message")
-            // catches official 'v3' and 'v3-tekken' templates
-            || tmpl_contains("[AVAILABLE_TOOLS]")
-        ) {
-            // Official mistral 'v1', 'v3' and 'v3-tekken' templates
-            // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
-            // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
-            if (tmpl_contains(" [INST]")) {
-                return LLM_CHAT_TEMPLATE_MISTRAL_V1;
-            } else if (tmpl_contains("\"[INST]\"")) {
-                return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN;
-            }
-            return LLM_CHAT_TEMPLATE_MISTRAL_V3;
-        } else {
-            // llama2 template and its variants
-            // [variant] support system message
-            // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
-            bool support_system_message = tmpl_contains("<<SYS>>");
-            bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
-            bool strip_message = tmpl_contains("content.strip()");
-            if (strip_message) {
-                return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
-            } else if (add_bos_inside_history) {
-                return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
-            } else if (support_system_message) {
-                return LLM_CHAT_TEMPLATE_LLAMA_2_SYS;
-            } else {
-                return LLM_CHAT_TEMPLATE_LLAMA_2;
-            }
-        }
-    } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
-        return LLM_CHAT_TEMPLATE_PHI_3;
-    } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
-        return LLM_CHAT_TEMPLATE_FALCON_3;
-    } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
-        return LLM_CHAT_TEMPLATE_ZEPHYR;
-    } else if (tmpl_contains("bos_token + message['role']")) {
-        return LLM_CHAT_TEMPLATE_MONARCH;
-    } else if (tmpl_contains("")) {
-        return LLM_CHAT_TEMPLATE_GEMMA;
-    } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
-        // OrionStarAI/Orion-14B-Chat
-        return LLM_CHAT_TEMPLATE_ORION;
-    } else if (tmpl_contains("GPT4 Correct ")) {
-        // openchat/openchat-3.5-0106
-        return LLM_CHAT_TEMPLATE_OPENCHAT;
-    } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) {
-        // eachadea/vicuna-13b-1.1 (and Orca variant)
-        if (tmpl_contains("SYSTEM: ")) {
-            return LLM_CHAT_TEMPLATE_VICUNA_ORCA;
-        }
-        return LLM_CHAT_TEMPLATE_VICUNA;
-    } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) {
-        // deepseek-ai/deepseek-coder-33b-instruct
-        return LLM_CHAT_TEMPLATE_DEEPSEEK;
-    } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) {
-        // CohereForAI/c4ai-command-r-plus
-        return LLM_CHAT_TEMPLATE_COMMAND_R;
-    } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) {
-        return LLM_CHAT_TEMPLATE_LLAMA_3;
-    } else if (tmpl_contains("[gMASK]sop")) {
-        // chatglm3-6b
-        return LLM_CHAT_TEMPLATE_CHATGML_3;
-    } else if (tmpl_contains("[gMASK]")) {
-        return LLM_CHAT_TEMPLATE_CHATGML_4;
-    } else if (tmpl_contains(LU8("<用户>"))) {
-        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
-        return LLM_CHAT_TEMPLATE_MINICPM;
-    } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
-        return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
-    } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
-        // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
-        // EXAONE-3.0-7.8B-Instruct
-        return LLM_CHAT_TEMPLATE_EXAONE_3;
-    } else if (tmpl_contains("rwkv-world")) {
-        return LLM_CHAT_TEMPLATE_RWKV_WORLD;
-    } else if (tmpl_contains("<|start_of_role|>")) {
-        return LLM_CHAT_TEMPLATE_GRANITE;
-    } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
-        return LLM_CHAT_TEMPLATE_GIGACHAT;
-    } else if (tmpl_contains("<|role_start|>")) {
-        return LLM_CHAT_TEMPLATE_MEGREZ;
-    }
-    return LLM_CHAT_TEMPLATE_UNKNOWN;
-}
-
-// Simple version of "llama_apply_chat_template" that only works with strings
-// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
-static int32_t llama_chat_apply_template_internal(
-    const llm_chat_template tmpl,
-    const std::vector<const llama_chat_message *> & chat,
-    std::string & dest, bool add_ass) {
-    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
-    std::stringstream ss;
-    if (tmpl == LLM_CHAT_TEMPLATE_CHATML) {
-        // chatml template
-        for (auto message : chat) {
-            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
-        }
-        if (add_ass) {
-            ss << "<|im_start|>assistant\n";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
-        // Official mistral 'v7' template
-        // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
-        for (auto message : chat) {
-            std::string role(message->role);
-            std::string content(message->content);
-            if (role == "system") {
-                ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
-            } else if (role == "user") {
-                ss << "[INST] " << content << "[/INST]";
-            }
-            else {
-                ss << " " << content << "";
-            }
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
-            || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3
-            || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) {
-        // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
-        // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
-        std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : "";
-        std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " ";
-        bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3;
-        bool is_inside_turn = false;
-        for (auto message : chat) {
-            if (!is_inside_turn) {
-                ss << leading_space << "[INST]" << trailing_space;
-                is_inside_turn = true;
-            }
-            std::string role(message->role);
-            std::string content(message->content);
-            if (role == "system") {
-                ss << content << "\n\n";
-            } else if (role == "user") {
-                ss << content << leading_space << "[/INST]";
-            } else {
-                ss << trailing_space << (trim_assistant_message ? trim(content) : content) << "</s>";
-                is_inside_turn = false;
-            }
-        }
-    } else if (
-            tmpl == LLM_CHAT_TEMPLATE_LLAMA_2
-            || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS
-            || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS
-            || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) {
-        // llama2 template and its variants
-        // [variant] support system message
-        // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
-        bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2;
-        // [variant] add BOS inside history
-        bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
-        // [variant] trim spaces from the input message
-        bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
-        // construct the prompt
-        bool is_inside_turn = true; // skip BOS at the beginning
-        ss << "[INST] ";
-        for (auto message : chat) {
-            std::string content = strip_message ? trim(message->content) : message->content;
-            std::string role(message->role);
-            if (!is_inside_turn) {
-                is_inside_turn = true;
-                ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
-            }
-            if (role == "system") {
-                if (support_system_message) {
-                    ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
-                } else {
-                    // if the model does not support system message, we still include it in the first message, but without <<SYS>>
-                    ss << content << "\n";
-                }
-            } else if (role == "user") {
-                ss << content << " [/INST]";
-            } else {
-                ss << content << "</s>";
-                is_inside_turn = false;
-            }
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) {
-        // Phi 3
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
-        }
-        if (add_ass) {
-            ss << "<|assistant|>\n";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) {
-        // Falcon 3
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|" << role << "|>\n" << message->content << "\n";
-        }
-        if (add_ass) {
-            ss << "<|assistant|>\n";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) {
-        // zephyr template
-        for (auto message : chat) {
-            ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
-        }
-        if (add_ass) {
-            ss << "<|assistant|>\n";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) {
-        // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
-        for (auto message : chat) {
-            std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
-            ss << bos << message->role << "\n" << message->content << "</s>\n";
-        }
-        if (add_ass) {
-            ss << "<s>assistant\n";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) {
-        // google/gemma-7b-it
-        std::string system_prompt = "";
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
-                system_prompt = trim(message->content);
-                continue;
-            }
-            // in gemma, "assistant" is "model"
-            role = role == "assistant" ? "model" : message->role;
-            ss << "" << role << "\n";
-            if (!system_prompt.empty() && role != "model") {
-                ss << system_prompt << "\n\n";
-                system_prompt = "";
-            }
-            ss << trim(message->content) << "\n";
-        }
-        if (add_ass) {
-            ss << "model\n";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) {
-        // OrionStarAI/Orion-14B-Chat
-        std::string system_prompt = "";
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                // there is no system message support, we will merge it with user prompt
-                system_prompt = message->content;
-                continue;
-            } else if (role == "user") {
-                ss << "Human: ";
-                if (!system_prompt.empty()) {
-                    ss << system_prompt << "\n\n";
-                    system_prompt = "";
-                }
-                ss << message->content << "\n\nAssistant: ";
-            } else {
-                ss << message->content << "</s>";
-            }
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) {
-        // openchat/openchat-3.5-0106,
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << message->content << "<|end_of_turn|>";
-            } else {
-                role[0] = toupper(role[0]);
-                ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
-            }
-        }
-        if (add_ass) {
-            ss << "GPT4 Correct Assistant:";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
-        // eachadea/vicuna-13b-1.1 (and Orca variant)
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                // Orca-Vicuna variant uses a system prefix
-                if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
-                    ss << "SYSTEM: " << message->content << "\n";
-                } else {
-                    ss << message->content << "\n\n";
-                }
-            } else if (role == "user") {
-                ss << "USER: " << message->content << "\n";
-            } else if (role == "assistant") {
-                ss << "ASSISTANT: " << message->content << "\n";
-            }
-        }
-        if (add_ass) {
-            ss << "ASSISTANT:";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) {
-        // deepseek-ai/deepseek-coder-33b-instruct
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << message->content;
-            } else if (role == "user") {
-                ss << "### Instruction:\n" << message->content << "\n";
-            } else if (role == "assistant") {
-                ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
-            }
-        }
-        if (add_ass) {
-            ss << "### Response:\n";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) {
-        // CohereForAI/c4ai-command-r-plus
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
-            } else if (role == "user") {
-                ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
-            } else if (role == "assistant") {
-                ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
-            }
-        }
-        if (add_ass) {
-            ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) {
-        // Llama 3
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
-        }
-        if (add_ass) {
-            ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
-        // chatglm3-6b
-        ss << "[gMASK]" << "sop";
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|" << role << "|>" << "\n " << message->content;
-        }
-        if (add_ass) {
-            ss << "<|assistant|>";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
-        ss << "[gMASK]" << "";
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|" << role << "|>" << "\n" << message->content;
-        }
-        if (add_ass) {
-            ss << "<|assistant|>";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
-        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << LU8("<用户>");
-                ss << trim(message->content);
-                ss << "";
-            } else {
-                ss << trim(message->content);
-            }
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) {
-        // DeepSeek-V2
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << message->content << "\n\n";
-            } else if (role == "user") {
-                ss << "User: " << message->content << "\n\n";
-            } else if (role == "assistant") {
-                ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>");
-            }
-        }
-        if (add_ass) {
-            ss << "Assistant:";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
-        // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
-        // EXAONE-3.0-7.8B-Instruct
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
-            } else if (role == "user") {
-                ss << "[|user|]" << trim(message->content) << "\n";
-            } else if (role == "assistant") {
-                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
-            }
-        }
-        if (add_ass) {
-            ss << "[|assistant|]";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
-        // this template requires the model to have "\n\n" as EOT token
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << "User: " << message->content << "\n\nAssistant:";
-            } else {
-                ss << message->content << "\n\n";
-            }
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
-        // IBM Granite template
-        for (const auto & message : chat) {
-            std::string role(message->role);
-            ss << "<|start_of_role|>" << role << "<|end_of_role|>";
-            if (role == "assistant_tool_call") {
-                ss << "<|tool_call|>";
-            }
-            ss << message->content << "<|end_of_text|>\n";
-        }
-        if (add_ass) {
-            ss << "<|start_of_role|>assistant<|end_of_role|>\n";
-        }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
-        // GigaChat template
-        bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
-
-        // Handle system message if present
-        if (has_system) {
-            ss << "" << chat[0]->content << "<|message_sep|>";
-        } else {
-            ss << "";
-        }
-
-        // Process remaining messages
-        for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) {
-            std::string role(chat[i]->role);
-            if (role == "user") {
-                ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>"
-                << "available functions<|role_sep|>[]<|message_sep|>";
-            } else if (role == "assistant") {
-                ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>";
-            }
-        }
-
-        // Add generation prompt if needed
-        if (add_ass) {
-            ss << "assistant<|role_sep|>";
-        }
-    }  else if (tmpl == LLM_CHAT_TEMPLATE_MEGREZ) {
-        // Megrez template
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|role_start|>" << role << "<|role_end|>" << message->content << "<|turn_end|>";
-        }
-
-        if (add_ass) {
-            ss << "<|role_start|>assistant<|role_end|>";
-        }
-    } else {
-        // template not supported
-        return -1;
-    }
-    dest = ss.str();
-    return dest.size();
-}
-
 int32_t llama_chat_apply_template(
                 const struct llama_model * model,
                               const char * tmpl,
@@ -23333,7 +12447,7 @@ int32_t llama_chat_apply_template(
         }
         else {
             // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "chatml";  // see llama_chat_apply_template_internal
+            curr_tmpl = "chatml";  // see llm_chat_apply_template
         }
     }
 
@@ -23345,11 +12459,11 @@ int32_t llama_chat_apply_template(
     }
 
     std::string formatted_chat;
-    llm_chat_template detected_tmpl = llama_chat_detect_template(curr_tmpl);
+    llm_chat_template detected_tmpl = llm_chat_detect_template(curr_tmpl);
     if (detected_tmpl == LLM_CHAT_TEMPLATE_UNKNOWN) {
         return -1;
     }
-    int32_t res = llama_chat_apply_template_internal(detected_tmpl, chat_vec, formatted_chat, add_ass);
+    int32_t res = llm_chat_apply_template(detected_tmpl, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }
@@ -23359,15 +12473,6 @@ int32_t llama_chat_apply_template(
     return res;
 }
 
-int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
-    auto it = LLM_CHAT_TEMPLATES.begin();
-    for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) {
-        output[i] = it->first.c_str();
-        std::advance(it, 1);
-    }
-    return (int32_t) LLM_CHAT_TEMPLATES.size();
-}
-
 //
 // sampling
 //
@@ -23397,16 +12502,16 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix,
     return 0;
 }
 
-int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
+int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
     std::string str_split_path(split_path);
     char postfix[32];
     snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
     std::string str_postfix(postfix);
 
-    // check if dest ends with postfix
+    // check if split_prefix ends with postfix
     int size_prefix = str_split_path.size() - str_postfix.size();
     if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
+        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
         return size_prefix;
     }
 
@@ -23415,6 +12520,8 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
 
 const char * llama_print_system_info(void) {
     static std::string s;
+    s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
+
 
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
         auto * reg = ggml_backend_reg_get(i);
@@ -23435,6 +12542,10 @@ const char * llama_print_system_info(void) {
     return s.c_str();
 }
 
+//
+// perf
+//
+
 struct llama_perf_context_data llama_perf_context(const struct llama_context * ctx) {
     struct llama_perf_context_data data = {};
 
@@ -23470,47 +12581,3 @@ void llama_perf_context_reset(struct llama_context * ctx) {
     ctx->t_eval_us   = ctx->n_eval = 0;
     ctx->t_p_eval_us = ctx->n_p_eval = 0;
 }
-
-// For internal test use
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-) {
-    return ctx->model.tensors_by_name;
-}
-
-void llama_log_set(ggml_log_callback log_callback, void * user_data) {
-    ggml_log_set(log_callback, user_data);
-    g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
-    g_logger_state.log_callback_user_data = user_data;
-}
-
-static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
-    va_list args_copy;
-    va_copy(args_copy, args);
-    char buffer[128];
-    int len = vsnprintf(buffer, 128, format, args);
-    if (len < 128) {
-        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
-    } else {
-        char * buffer2 = new char[len + 1];
-        vsnprintf(buffer2, len + 1, format, args_copy);
-        buffer2[len] = 0;
-        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
-        delete[] buffer2;
-    }
-    va_end(args_copy);
-}
-
-void llama_log_internal(ggml_log_level level, const char * format, ...) {
-    va_list args;
-    va_start(args, format);
-    llama_log_internal_v(level, format, args);
-    va_end(args);
-}
-
-void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
-    (void) level;
-    (void) user_data;
-    fputs(text, stderr);
-    fflush(stderr);
-}
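
As a point of reference for the batch helpers removed above, a minimal usage sketch (not part of the patch; the context variable, token IDs and sizes are illustrative placeholders):

    // hypothetical caller code, assuming a valid struct llama_context * ctx
    llama_batch batch = llama_batch_init(/*n_tokens_alloc =*/ 4, /*embd =*/ 0, /*n_seq_max =*/ 1);
    for (int i = 0; i < 4; ++i) {
        batch.token[i]     = 100 + i;   // placeholder token IDs
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;         // everything on sequence 0
        batch.logits[i]    = (i == 3);  // request logits for the last token only
    }
    batch.n_tokens = 4;
    // llama_decode(ctx, batch); then read llama_get_logits_ith(ctx, -1)
    llama_batch_free(batch);
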
diff --git a/src/unicode.cpp b/src/unicode.cpp
index 8ed6b1a51..7aca6544b 100644
--- a/src/unicode.cpp
+++ b/src/unicode.cpp
@@ -667,18 +667,24 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
         { "\\p{N}", unicode_cpt_flags::NUMBER },
         { "\\p{L}", unicode_cpt_flags::LETTER },
         { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
+        { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
+        { "\\p{S}", unicode_cpt_flags::SYMBOL },
     };
 
     static const std::map<int, int> k_ucat_cpt = {
         { unicode_cpt_flags::NUMBER,      0xD1 },
         { unicode_cpt_flags::LETTER,      0xD2 },
         { unicode_cpt_flags::PUNCTUATION, 0xD3 },
+        { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
+        { unicode_cpt_flags::SYMBOL,      0xD5 },
     };
 
     static const std::map<int, std::string> k_ucat_map = {
         { unicode_cpt_flags::NUMBER,      "\x30-\x39" }, // 0-9
         { unicode_cpt_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
         { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
+        { unicode_cpt_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
     };
 
     // compute collapsed codepoints only if needed by at least one regex
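
The two new categories plug into the same collapse tables as \p{N}, \p{L} and \p{P}; a hedged sketch of a caller exercising them (the input string and split patterns are made up for illustration, not taken from any tokenizer config):

    // hypothetical call: symbol characters ($, +, =, ...) can now be matched via \p{S}
    const std::vector<std::string> parts = unicode_regex_split(
        "a + b = c",
        { "[\\p{L}]+", "[\\p{S}]+" });
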
diff --git a/tests/test-autorelease.cpp b/tests/test-autorelease.cpp
index 57fa00011..ba084a91a 100644
--- a/tests/test-autorelease.cpp
+++ b/tests/test-autorelease.cpp
@@ -13,10 +13,10 @@ int main(int argc, char ** argv) {
 
     std::thread([&model_path]() {
         llama_backend_init();
-        auto * model = llama_load_model_from_file(model_path, llama_model_default_params());
+        auto * model = llama_model_load_from_file(model_path, llama_model_default_params());
         auto * ctx = llama_new_context_with_model(model, llama_context_default_params());
         llama_free(ctx);
-        llama_free_model(model);
+        llama_model_free(model);
         llama_backend_free();
     }).join();
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index c79acffd2..1e892f663 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3937,7 +3937,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
 
-    for (int bs : {1, 512}) {
+    for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
                 test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1,  1}, {1, 1}));
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 51bfb155b..f1f9aec4d 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -78,7 +78,9 @@ int main(void) {
         // ai-sage/GigaChat-20B-A3B-instruct
         "{% if messages[0]['role'] == 'system' -%}\n    {%- set loop_messages = messages[1:] -%}\n    {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n    {%- set loop_messages = messages -%}\n    {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n    {% endif %}\n    \n    {%- if loop.index0 == 0 -%}\n        {{ system_message -}}\n    {%- endif -%}\n    {%- if message['role'] == 'user' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n        {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3]  + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if message['role'] == 'assistant' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if loop.last and add_generation_prompt -%}\n        {{ 'assistant' + additional_special_tokens[0] -}}\n    {%- endif -%}\n{%- endfor %}",
         // Infinigence/Megrez-3B-Instruct
-        u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}"
+        u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
+        // phi-4
+        "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -137,6 +139,8 @@ int main(void) {
         "You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>   I am an assistant   <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
         // Infinigence/Megrez-3B-Instruct
         "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|>   I am an assistant   <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
+        // phi-4
+        "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|>   I am an assistant   <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
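
For the new phi-4 case, the template string is driven through llama_chat_apply_template with a null model, as for the other entries in this test; a hedged sketch of that call (the phi4_template variable and the buffer size are illustrative):

    // hypothetical standalone call following the test's convention of passing
    // the Jinja template string directly and no model
    llama_chat_message conversation[] = {
        { "system", "You are a helpful assistant" },
        { "user",   "Hello"                       },
    };
    std::vector<char> buf(1024);
    const int32_t n_written = llama_chat_apply_template(
        /*model =*/ nullptr, phi4_template, conversation, 2,
        /*add_ass =*/ true, buf.data(), (int32_t) buf.size());
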
diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp
index 1bb5fb47c..611957ac0 100644
--- a/tests/test-gguf.cpp
+++ b/tests/test-gguf.cpp
@@ -15,66 +15,71 @@ constexpr int offset_has_tensors = 2000;
 constexpr int offset_has_data    = 3000;
 
 enum handcrafted_file_type {
-    HANDCRAFTED_HEADER_BAD_MAGIC          =  10,
-    HANDCRAFTED_HEADER_BAD_VERSION_1      =  20,
-    HANDCRAFTED_HEADER_BAD_VERSION_FUTURE =  30,
-    HANDCRAFTED_HEADER_BAD_N_TENSORS      =  40,
-    HANDCRAFTED_HEADER_BAD_N_KV           =  50,
-    HANDCRAFTED_HEADER_EMPTY              = 800,
+    HANDCRAFTED_HEADER_BAD_MAGIC           =  10,
+    HANDCRAFTED_HEADER_BAD_VERSION_1       =  20,
+    HANDCRAFTED_HEADER_BAD_VERSION_FUTURE  =  30,
+    HANDCRAFTED_HEADER_BAD_N_TENSORS       =  40,
+    HANDCRAFTED_HEADER_BAD_N_KV            =  50,
+    HANDCRAFTED_HEADER_EMPTY               = 800,
 
-    HANDCRAFTED_KV_BAD_KEY_SIZE           =  10 + offset_has_kv,
-    HANDCRAFTED_KV_BAD_TYPE               =  20 + offset_has_kv,
-    HANDCRAFTED_KV_BAD_VALUE_SIZE         =  30 + offset_has_kv,
-    HANDCRAFTED_KV_DUPLICATE_KEY          =  40 + offset_has_kv,
-    HANDCRAFTED_KV_SUCCESS                = 800 + offset_has_kv,
+    HANDCRAFTED_KV_BAD_KEY_SIZE            =  10 + offset_has_kv,
+    HANDCRAFTED_KV_BAD_TYPE                =  20 + offset_has_kv,
+    // HANDCRAFTED_KV_BAD_VALUE_SIZE          =  30 + offset_has_kv, // removed because it can result in allocations > 1 TB (default sanitizer limit)
+    HANDCRAFTED_KV_DUPLICATE_KEY           =  40 + offset_has_kv,
+    HANDCRAFTED_KV_BAD_ALIGN               =  50 + offset_has_kv,
+    HANDCRAFTED_KV_SUCCESS                 = 800 + offset_has_kv,
 
-    HANDCRAFTED_TENSORS_BAD_NAME_SIZE     =  10 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_N_DIMS        =  20 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_SHAPE         =  30 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_NE_TOO_BIG        =  40 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_TYPE          =  50 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_OFFSET        =  60 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_DUPLICATE_NAME    =  70 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_ALIGNMENT     =  80 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_SUCCESS           = 800 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_CUSTOM_ALIGN      = 810 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_NAME_SIZE      =  10 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_N_DIMS         =  20 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_SHAPE          =  30 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_NE_TOO_BIG         =  40 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_TYPE           =  50 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_OFFSET         =  60 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_DUPLICATE_NAME     =  70 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_ALIGN          =  75 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN =  80 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_SUCCESS            = 800 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_CUSTOM_ALIGN       = 810 + offset_has_tensors,
 
-    HANDCRAFTED_DATA_NOT_ENOUGH_DATA      =  10 + offset_has_data,
-    HANDCRAFTED_DATA_BAD_ALIGNMENT        =  20 + offset_has_data,
-    HANDCRAFTED_DATA_SUCCESS              = 800 + offset_has_data,
-    HANDCRAFTED_DATA_CUSTOM_ALIGN         = 810 + offset_has_data,
+    HANDCRAFTED_DATA_NOT_ENOUGH_DATA       =  10 + offset_has_data,
+    HANDCRAFTED_DATA_BAD_ALIGN             =  15 + offset_has_data,
+    HANDCRAFTED_DATA_INCONSISTENT_ALIGN    =  20 + offset_has_data,
+    HANDCRAFTED_DATA_SUCCESS               = 800 + offset_has_data,
+    HANDCRAFTED_DATA_CUSTOM_ALIGN          = 810 + offset_has_data,
 };
 
 std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) {
     switch (hft) {
-        case HANDCRAFTED_HEADER_BAD_MAGIC:          return "HEADER_BAD_MAGIC";
-        case HANDCRAFTED_HEADER_BAD_VERSION_1:      return "HEADER_BAD_VERSION_1";
-        case HANDCRAFTED_HEADER_BAD_VERSION_FUTURE: return "HEADER_BAD_VERSION_FUTURE";
-        case HANDCRAFTED_HEADER_BAD_N_KV:           return "HEADER_BAD_N_KV";
-        case HANDCRAFTED_HEADER_BAD_N_TENSORS:      return "HEADER_BAD_N_TENSORS";
-        case HANDCRAFTED_HEADER_EMPTY:              return "HEADER_EMPTY";
+        case HANDCRAFTED_HEADER_BAD_MAGIC:           return "HEADER_BAD_MAGIC";
+        case HANDCRAFTED_HEADER_BAD_VERSION_1:       return "HEADER_BAD_VERSION_1";
+        case HANDCRAFTED_HEADER_BAD_VERSION_FUTURE:  return "HEADER_BAD_VERSION_FUTURE";
+        case HANDCRAFTED_HEADER_BAD_N_KV:            return "HEADER_BAD_N_KV";
+        case HANDCRAFTED_HEADER_BAD_N_TENSORS:       return "HEADER_BAD_N_TENSORS";
+        case HANDCRAFTED_HEADER_EMPTY:               return "HEADER_EMPTY";
 
-        case HANDCRAFTED_KV_BAD_KEY_SIZE:           return "KV_BAD_KEY_SIZE";
-        case HANDCRAFTED_KV_BAD_TYPE:               return "KV_BAD_TYPE";
-        case HANDCRAFTED_KV_BAD_VALUE_SIZE:         return "KV_BAD_VALUE_SIZE";
-        case HANDCRAFTED_KV_DUPLICATE_KEY:          return "KV_DUPLICATE_KEY";
-        case HANDCRAFTED_KV_SUCCESS:                return "KV_RANDOM_KV";
+        case HANDCRAFTED_KV_BAD_KEY_SIZE:            return "KV_BAD_KEY_SIZE";
+        case HANDCRAFTED_KV_BAD_TYPE:                return "KV_BAD_TYPE";
+        case HANDCRAFTED_KV_DUPLICATE_KEY:           return "KV_DUPLICATE_KEY";
+        case HANDCRAFTED_KV_BAD_ALIGN:               return "KV_BAD_ALIGN";
+        case HANDCRAFTED_KV_SUCCESS:                 return "KV_RANDOM_KV";
 
-        case HANDCRAFTED_TENSORS_BAD_NAME_SIZE:     return "TENSORS_BAD_NAME_SIZE";
-        case HANDCRAFTED_TENSORS_BAD_N_DIMS:        return "TENSORS_BAD_N_DIMS";
-        case HANDCRAFTED_TENSORS_BAD_SHAPE:         return "TENSORS_BAD_SHAPE";
-        case HANDCRAFTED_TENSORS_NE_TOO_BIG:        return "TENSORS_NE_TOO_BIG";
-        case HANDCRAFTED_TENSORS_BAD_TYPE:          return "TENSORS_BAD_TYPE";
-        case HANDCRAFTED_TENSORS_BAD_OFFSET:        return "TENSORS_BAD_OFFSET";
-        case HANDCRAFTED_TENSORS_DUPLICATE_NAME:    return "TENSORS_DUPLICATE_NAME";
-        case HANDCRAFTED_TENSORS_BAD_ALIGNMENT:     return "TENSORS_BAD_ALIGNMENT";
-        case HANDCRAFTED_TENSORS_SUCCESS:           return "TENSORS_SUCCESS";
-        case HANDCRAFTED_TENSORS_CUSTOM_ALIGN:      return "TENSORS_CUSTOM_ALIGN";
+        case HANDCRAFTED_TENSORS_BAD_NAME_SIZE:      return "TENSORS_BAD_NAME_SIZE";
+        case HANDCRAFTED_TENSORS_BAD_N_DIMS:         return "TENSORS_BAD_N_DIMS";
+        case HANDCRAFTED_TENSORS_BAD_SHAPE:          return "TENSORS_BAD_SHAPE";
+        case HANDCRAFTED_TENSORS_NE_TOO_BIG:         return "TENSORS_NE_TOO_BIG";
+        case HANDCRAFTED_TENSORS_BAD_TYPE:           return "TENSORS_BAD_TYPE";
+        case HANDCRAFTED_TENSORS_BAD_OFFSET:         return "TENSORS_BAD_OFFSET";
+        case HANDCRAFTED_TENSORS_DUPLICATE_NAME:     return "TENSORS_DUPLICATE_NAME";
+        case HANDCRAFTED_TENSORS_BAD_ALIGN:          return "TENSORS_BAD_ALIGN";
+        case HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN: return "TENSORS_INCONSISTENT_ALIGN";
+        case HANDCRAFTED_TENSORS_SUCCESS:            return "TENSORS_SUCCESS";
+        case HANDCRAFTED_TENSORS_CUSTOM_ALIGN:       return "TENSORS_CUSTOM_ALIGN";
 
-        case HANDCRAFTED_DATA_NOT_ENOUGH_DATA:      return "DATA_NOT_ENOUGH_DATA";
-        case HANDCRAFTED_DATA_BAD_ALIGNMENT:        return "DATA_BAD_ALIGNMENT";
-        case HANDCRAFTED_DATA_SUCCESS:              return "DATA_SUCCESS";
-        case HANDCRAFTED_DATA_CUSTOM_ALIGN:         return "DATA_CUSTOM_ALIGN";
+        case HANDCRAFTED_DATA_NOT_ENOUGH_DATA:       return "DATA_NOT_ENOUGH_DATA";
+        case HANDCRAFTED_DATA_BAD_ALIGN:             return "DATA_BAD_ALIGN";
+        case HANDCRAFTED_DATA_INCONSISTENT_ALIGN:    return "DATA_INCONSISTENT_ALIGN";
+        case HANDCRAFTED_DATA_SUCCESS:               return "DATA_SUCCESS";
+        case HANDCRAFTED_DATA_CUSTOM_ALIGN:          return "DATA_CUSTOM_ALIGN";
     }
     GGML_ABORT("fatal error");
 }
@@ -140,31 +145,41 @@ std::vector<std::pair<enum gguf_type, enum gguf_type>> get_kv_types(std::mt19937
     return kv_types;
 }
 
-static void helper_write(const void * data, const size_t nbytes, FILE * file) {
+template <typename T>
+static void helper_write(FILE * file, const T & val) {
+    GGML_ASSERT(fwrite(&val, 1, sizeof(val), file) == sizeof(val));
+}
+
+static void helper_write(FILE * file, const void * data, const size_t nbytes) {
     GGML_ASSERT(fwrite(data, 1, nbytes, file) == nbytes);
 }
 
 static FILE * get_handcrafted_file(const unsigned int seed, const enum handcrafted_file_type hft, const int extra_bytes = 0) {
     FILE * file = tmpfile();
 
+    if (!file) {
+        return file;
+    }
+
     std::mt19937 rng(seed);
+    uint32_t alignment = GGUF_DEFAULT_ALIGNMENT;
 
     if (hft == HANDCRAFTED_HEADER_BAD_MAGIC) {
         const char bad_magic[4] = {'F', 'U', 'G', 'G'};
-        helper_write(bad_magic, sizeof(bad_magic), file);
+        helper_write(file, bad_magic, sizeof(bad_magic));
     } else {
-        helper_write(GGUF_MAGIC, 4, file);
+        helper_write(file, GGUF_MAGIC, 4);
     }
 
     if (hft == HANDCRAFTED_HEADER_BAD_VERSION_1) {
         const uint32_t version = 1;
-        helper_write(&version, sizeof(version), file);
+        helper_write(file, version);
     } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_FUTURE) {
         const uint32_t version = GGUF_VERSION + 1;
-        helper_write(&version, sizeof(version), file);
+        helper_write(file, version);
     } else {
         const uint32_t version = GGUF_VERSION;
-        helper_write(&version, sizeof(version), file);
+        helper_write(file, version);
     }
 
     std::vector tensor_configs;
@@ -174,10 +189,10 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
 
     if (hft == HANDCRAFTED_HEADER_BAD_N_TENSORS) {
         const uint64_t n_tensors = -1;
-        helper_write(&n_tensors, sizeof(n_tensors), file);
+        helper_write(file, n_tensors);
     } else {
         const uint64_t n_tensors = tensor_configs.size();
-        helper_write(&n_tensors, sizeof(n_tensors), file);
+        helper_write(file, n_tensors);
     }
 
     std::vector<std::pair<enum gguf_type, enum gguf_type>> kv_types;
@@ -186,41 +201,49 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
     }
     {
         uint64_t n_kv = kv_types.size();
-        if (hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+        if (hft == HANDCRAFTED_KV_BAD_ALIGN      ||
+            hft == HANDCRAFTED_TENSORS_BAD_ALIGN || hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN ||
+            hft == HANDCRAFTED_DATA_BAD_ALIGN    || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+
             n_kv += 1;
         } else if (hft == HANDCRAFTED_HEADER_BAD_N_KV) {
             n_kv = -1;
         }
-        helper_write(&n_kv, sizeof(n_kv), file);
+        helper_write(file, n_kv);
     }
 
     if (hft < offset_has_kv) {
+        while (ftell(file) % alignment != 0) {
+            const char pad = 0;
+            helper_write(file, pad);
+        }
+
         for (int i = 0; i < extra_bytes; ++i) {
             const char tmp = 0;
-            helper_write(&tmp, sizeof(tmp), file);
+            helper_write(file, tmp);
         }
         rewind(file);
         return file;
     }
 
     for (int i = 0; i < int(kv_types.size()); ++i) {
-        const enum gguf_type type     = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? -1 : kv_types[i].first);
-        const enum gguf_type type_arr = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? -1 : kv_types[i].second);
+        const enum gguf_type type     = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? GGUF_TYPE_COUNT : kv_types[i].first);
+        const enum gguf_type type_arr = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? GGUF_TYPE_COUNT : kv_types[i].second);
 
         const std::string key = "my_key_" + std::to_string((hft == HANDCRAFTED_KV_DUPLICATE_KEY ? i/2 : i));
 
         if (hft == HANDCRAFTED_KV_BAD_KEY_SIZE) {
             const uint64_t n = -1;
-            helper_write(&n, sizeof(n), file);
+            helper_write(file, n);
         } else {
             const uint64_t n = key.length();
-            helper_write(&n, sizeof(n), file);
+            helper_write(file, n);
         }
-        helper_write(key.data(), key.length(), file);
+        helper_write(file, key.data(), key.length());
 
         {
             const int32_t type32 = int32_t(type);
-            helper_write(&type32, sizeof(type32), file);
+            helper_write(file, type32);
         }
 
         uint32_t data[16];
@@ -233,69 +256,67 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
 
         if (type == GGUF_TYPE_STRING) {
             const uint64_t n = rng() % sizeof(data);
-            helper_write(&n,   sizeof(n), file);
-            helper_write(data,        n,  file);
+            helper_write(file, n);
+            helper_write(file, data, n);
             continue;
         }
 
         if (type == GGUF_TYPE_ARRAY) {
             {
                 const int32_t type32 = int32_t(type_arr);
-                helper_write(&type32, sizeof(type32), file);
+                helper_write(file, type32);
             }
             if (type_arr == GGUF_TYPE_STRING) {
                 const uint64_t nstr = rng() % (16 + 1);
-                helper_write(&nstr, sizeof(nstr), file);
+                helper_write(file, nstr);
                 for (uint64_t istr = 0; istr < nstr; ++istr) {
                     const uint64_t n = rng() % (sizeof(uint32_t) + 1);
-                    helper_write(&n,          sizeof(n), file);
-                    helper_write(&data[istr],        n,  file);
+                    helper_write(file, n);
+                    helper_write(file, &data[istr], n);
                 }
                 continue;
             }
             const size_t type_size = gguf_type_size(type_arr);
             const uint64_t n = (rng() % sizeof(data)) / type_size;
-            helper_write(&n,    sizeof(n),   file);
-            helper_write(&data, n*type_size, file);
+            helper_write(file, n);
+            helper_write(file, &data, n*type_size);
             continue;
         }
 
-        size_t type_size = hft == HANDCRAFTED_KV_BAD_TYPE ? 1 : gguf_type_size(type);
-        if (hft == HANDCRAFTED_KV_BAD_VALUE_SIZE) {
-            type_size += rng() % 3;
-        }
-        helper_write(data, type_size, file);
+        helper_write(file, data, hft == HANDCRAFTED_KV_BAD_TYPE ? 1 : gguf_type_size(type));
     }
 
-    if (hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
-        const std::string key = "general.alignment";
-        {
-            const uint64_t n = key.length();
-            helper_write(&n, sizeof(n), file);
-        }
-        helper_write(key.data(), key.length(), file);
+    if (hft == HANDCRAFTED_KV_BAD_ALIGN      ||
+        hft == HANDCRAFTED_TENSORS_BAD_ALIGN || hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN ||
+        hft == HANDCRAFTED_DATA_BAD_ALIGN    || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+
+        const uint64_t n = strlen(GGUF_KEY_GENERAL_ALIGNMENT);
+        helper_write(file, n);
+        helper_write(file, GGUF_KEY_GENERAL_ALIGNMENT, n);
 
         const int32_t type = gguf_type(GGUF_TYPE_UINT32);
-        helper_write(&type, sizeof(type), file);
+        helper_write(file, type);
 
-        const uint32_t alignment = GGUF_DEFAULT_ALIGNMENT + 1;
-        helper_write(&alignment, sizeof(alignment), file);
+        alignment = expect_context_not_null(hft) ? 1 : 13;
+        helper_write(file, alignment);
     }
 
     if (hft < offset_has_tensors) {
+        while (ftell(file) % alignment != 0) {
+            const char pad = 0;
+            helper_write(file, pad);
+        }
+
         for (int i = 0; i < extra_bytes; ++i) {
             const char tmp = 0;
-            helper_write(&tmp, sizeof(tmp), file);
+            helper_write(file, tmp);
         }
         rewind(file);
         return file;
     }
 
-    uint32_t alignment = GGUF_DEFAULT_ALIGNMENT;
-    if (hft == HANDCRAFTED_TENSORS_BAD_ALIGNMENT || hft == HANDCRAFTED_DATA_BAD_ALIGNMENT) {
-        alignment -= 1;
-    } else if (hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
-        alignment += 1;
+    if (hft == HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN || hft == HANDCRAFTED_DATA_INCONSISTENT_ALIGN) {
+        alignment = 1;
     }
 
     uint64_t offset = 0;
@@ -313,9 +334,9 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         }
         {
             const uint64_t n = name.length();
-            helper_write(&n, sizeof(n), file);
+            helper_write(file, n);
         }
-        helper_write(name.data(), name.length(), file);
+        helper_write(file, name.data(), name.length());
 
         uint32_t n_dims = hft == HANDCRAFTED_TENSORS_NE_TOO_BIG ? 2 : 1;
         for (int i = GGML_MAX_DIMS-1; i >= 1; --i) {
@@ -326,35 +347,35 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         }
         if (hft == HANDCRAFTED_TENSORS_BAD_N_DIMS) {
             const uint32_t n_dims_bad = GGML_MAX_DIMS + 1;
-            helper_write(&n_dims_bad, sizeof(n_dims_bad), file);
+            helper_write(file, n_dims_bad);
         } else {
-            helper_write(&n_dims,     sizeof(n_dims),     file);
+            helper_write(file, n_dims);
         }
 
         if (hft == HANDCRAFTED_TENSORS_BAD_SHAPE) {
             for (uint32_t j = 0; j < n_dims; ++j) {
                 const int64_t bad_dim = -1;
-                helper_write(&bad_dim, sizeof(bad_dim), file);
+                helper_write(file, bad_dim);
             }
         } else if (hft == HANDCRAFTED_TENSORS_NE_TOO_BIG){
             for (uint32_t j = 0; j < n_dims; ++j) {
                 const int64_t big_dim = 4*int64_t(INT32_MAX);
-                helper_write(&big_dim, sizeof(big_dim), file);
+                helper_write(file, big_dim);
             }
         } else {
-            helper_write(shape.data(), n_dims*sizeof(int64_t), file);
+            helper_write(file, shape.data(), n_dims*sizeof(int64_t));
         }
 
         {
-            const int32_t type32 = hft == HANDCRAFTED_TENSORS_BAD_TYPE ? -1 : int32_t(type);
-            helper_write(&type32, sizeof(type32), file);
+            const int32_t type32 = hft == HANDCRAFTED_TENSORS_BAD_TYPE ? GGML_TYPE_COUNT : int32_t(type);
+            helper_write(file, type32);
         }
 
         if (hft == HANDCRAFTED_TENSORS_BAD_OFFSET) {
             const uint64_t bad_offset = -1;
-            helper_write(&bad_offset, sizeof(bad_offset), file);
+            helper_write(file, bad_offset);
         } else {
-            helper_write(&offset, sizeof(offset), file);
+            helper_write(file, offset);
         }
 
         int64_t ne = shape[0];
@@ -364,12 +385,9 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         offset += GGML_PAD(ggml_row_size(type, ne), alignment);
     }
 
-    const uint32_t alignment_overshoot = ftell(file) % alignment;
-    if (alignment_overshoot != 0) {
-        for (size_t i = alignment_overshoot; i < alignment; ++i) {
-            const char pad = 0;
-            helper_write(&pad, sizeof(pad), file);
-        }
+    while (ftell(file) % alignment != 0) {
+        const char pad = 0;
+        helper_write(file, pad);
     }
 
     if (hft >= offset_has_data) {
@@ -380,13 +398,13 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         }
         for (uint64_t i = 0; i < nbytes; ++i) {
             const uint8_t random_byte = i % 256;
-            helper_write(&random_byte, sizeof(random_byte), file);
+            helper_write(file, random_byte);
         }
     }
 
     for (int i = 0; i < extra_bytes; ++i) {
         const char tmp = 0;
-        helper_write(&tmp, sizeof(tmp), file);
+        helper_write(file, tmp);
     }
     rewind(file);
     return file;
@@ -505,6 +523,16 @@ static bool handcrafted_check_kv(const gguf_context * gguf_ctx, const unsigned i
             }
 
             const char * data_gguf = reinterpret_cast<const char *>(gguf_get_arr_data(gguf_ctx, id));
+
+            if (type_arr == GGUF_TYPE_BOOL) {
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    if (bool(data8[arr_i]) != bool(data_gguf[arr_i])) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
             if (!std::equal(data8, data8 + arr_n*type_size, data_gguf)) {
                 ok = false;
             }
@@ -512,12 +540,20 @@ static bool handcrafted_check_kv(const gguf_context * gguf_ctx, const unsigned i
         }
 
         const char * data_gguf = reinterpret_cast<const char *>(gguf_get_val_data(gguf_ctx, id));
+
+        if (type == GGUF_TYPE_BOOL) {
+            if (bool(*data8) != bool(*data_gguf)) {
+                ok = false;
+            }
+            continue;
+        }
+
         if (!std::equal(data8, data8 + gguf_type_size(type), data_gguf)) {
             ok = false;
         }
     }
 
-    const uint32_t expected_alignment = alignment_defined ? GGUF_DEFAULT_ALIGNMENT + 1 : GGUF_DEFAULT_ALIGNMENT;
+    const uint32_t expected_alignment = alignment_defined ? 1 : GGUF_DEFAULT_ALIGNMENT;
     if (gguf_get_alignment(gguf_ctx) != expected_alignment) {
         ok = false;
     }
@@ -539,7 +575,7 @@ static bool handcrafted_check_tensors(const gguf_context * gguf_ctx, const unsig
 
     bool ok = true;
 
-    const int id_alignment = gguf_find_key(gguf_ctx, "general.alignment");
+    const int id_alignment = gguf_find_key(gguf_ctx, GGUF_KEY_GENERAL_ALIGNMENT);
     const uint32_t alignment = id_alignment >= 0 ? gguf_get_val_u32(gguf_ctx, id_alignment) : GGUF_DEFAULT_ALIGNMENT;
 
     uint64_t expected_offset = 0;
@@ -607,7 +643,7 @@ static bool handcrafted_check_tensor_data(const gguf_context * gguf_ctx, const u
 
         std::vector<uint8_t> data(size);
         GGML_ASSERT(fseek(file, gguf_get_data_offset(gguf_ctx) + offset, SEEK_SET) == 0);
-        GGML_ASSERT(fread(data.data(), 1, size, file) == size);
+        GGML_ASSERT(fread(data.data(), 1, data.size(), file) == data.size());
 
         for (size_t j = 0; j < size; ++j) {
             const uint8_t expected_byte = (j + offset) % 256;
@@ -627,15 +663,15 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
     const std::vector<handcrafted_file_type> hfts = {
         HANDCRAFTED_HEADER_BAD_MAGIC,
         HANDCRAFTED_HEADER_BAD_VERSION_1,
-        // HANDCRAFTED_FILE_TYPE_BAD_VERSION_FUTURE, // FIXME
+        HANDCRAFTED_HEADER_BAD_VERSION_FUTURE,
         HANDCRAFTED_HEADER_BAD_N_KV,
         HANDCRAFTED_HEADER_BAD_N_TENSORS,
         HANDCRAFTED_HEADER_EMPTY,
 
         HANDCRAFTED_KV_BAD_KEY_SIZE,
         HANDCRAFTED_KV_BAD_TYPE,
-        // HANDCRAFTED_KV_BAD_VALUE_SIZE, // FIXME sanitizer limit
-        // HANDCRAFTED_FILE_TYPE_DUPLICATE_KEY, // FIXME
+        HANDCRAFTED_KV_DUPLICATE_KEY,
+        HANDCRAFTED_KV_BAD_ALIGN,
         HANDCRAFTED_KV_SUCCESS,
 
         HANDCRAFTED_TENSORS_BAD_NAME_SIZE,
@@ -643,14 +679,16 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
         HANDCRAFTED_TENSORS_BAD_SHAPE,
         HANDCRAFTED_TENSORS_NE_TOO_BIG,
         HANDCRAFTED_TENSORS_BAD_TYPE,
-        // HANDCRAFTED_TENSORS_BAD_OFFSET, // FIXME
+        HANDCRAFTED_TENSORS_BAD_OFFSET,
         HANDCRAFTED_TENSORS_DUPLICATE_NAME,
-        // HANDCRAFTED_TENSORS_BAD_ALIGNMENT, // FIXME
+        HANDCRAFTED_TENSORS_BAD_ALIGN,
+        HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN,
         HANDCRAFTED_TENSORS_SUCCESS,
         HANDCRAFTED_TENSORS_CUSTOM_ALIGN,
 
         HANDCRAFTED_DATA_NOT_ENOUGH_DATA,
-        // HANDCRAFTED_DATA_BAD_ALIGNMENT, // FIXME
+        HANDCRAFTED_DATA_BAD_ALIGN,
+        HANDCRAFTED_DATA_INCONSISTENT_ALIGN,
         HANDCRAFTED_DATA_SUCCESS,
         HANDCRAFTED_DATA_CUSTOM_ALIGN,
     };
@@ -674,6 +712,7 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
             /*no_alloc =*/ false,
             /*ctx      =*/ hft >= offset_has_data ? &ctx : nullptr,
         };
+
         struct gguf_context * gguf_ctx = gguf_init_from_file_impl(file, gguf_params);
 
         if (expect_context_not_null(hft)) {
@@ -689,7 +728,7 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
         }
         ntest++;
 
-        if (false && hft >= offset_has_data && !expect_context_not_null(hft)) { // FIXME
+        if (hft >= offset_has_data && !expect_context_not_null(hft)) {
             printf("%s:   - no_dangling_ggml_context_pointer: ", __func__);
             if (ctx) {
                 printf("\033[1;31mFAIL\033[0m\n");
@@ -700,23 +739,6 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
             ntest++;
         }
 
-        if (false && expect_context_not_null(hft)) { // FIXME
-            FILE * file_eb = get_handcrafted_file(seed, hft, /*extra_bytes =*/ 1);
-            struct gguf_context * gguf_ctx_eb = gguf_init_from_file_impl(file_eb, gguf_params);
-
-            printf("%s:   - context_null_with_extra_bytes: ", __func__);
-            if (gguf_ctx_eb) {
-                printf("\033[1;31mFAIL\033[0m\n");
-            } else {
-                printf("\033[1;32mOK\033[0m\n");
-                npass++;
-            }
-            ntest++;
-
-            gguf_free(gguf_ctx_eb);
-            fclose(file_eb);
-        }
-
         const bool alignment_defined = hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN;
 
         if (expect_context_not_null(hft)) {
@@ -763,14 +785,15 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
             ntest++;
         }
 
+        fclose(file);
         if (gguf_ctx) {
             ggml_free(ctx);
             gguf_free(gguf_ctx);
         }
-        fclose(file);
         printf("\n");
     }
 
+
     return std::make_pair(npass, ntest);
 }
 
@@ -789,10 +812,6 @@ static struct random_gguf_context_result get_random_gguf_context(ggml_backend_t
         const std::string key = "my_key_" + std::to_string(rng() % 1024);
         const enum gguf_type type = gguf_type(rng() % GGUF_TYPE_COUNT);
 
-        if (type == GGUF_TYPE_STRING || type == GGUF_TYPE_ARRAY) {
-            continue; // FIXME memory leak
-        }
-
         switch (type) {
             case GGUF_TYPE_UINT8:   gguf_set_val_u8  (gguf_ctx, key.c_str(), rng() % (1 <<  7));             break;
             case GGUF_TYPE_INT8:    gguf_set_val_i8  (gguf_ctx, key.c_str(), rng() % (1 <<  7) - (1 <<  6)); break;
@@ -826,6 +845,9 @@ static struct random_gguf_context_result get_random_gguf_context(ggml_backend_t
                         std::vector<uint32_t> random_data((nbytes + sizeof(uint32_t) - 1) / sizeof(uint32_t));
                         for (size_t j = 0; j < random_data.size(); ++j) {
                             random_data[j] = rng();
+                            if (type_arr == GGUF_TYPE_BOOL) {
+                                random_data[j] &= 0x01010101; // the sanitizer complains if booleans are not 0 or 1
+                            }
                         }
                         gguf_set_arr_data(gguf_ctx, key.c_str(), type_arr, random_data.data(), ne);
                     } break;
@@ -928,6 +950,17 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
                 continue;
             }
 
+            if (type_arr == GGUF_TYPE_BOOL) {
+                const int8_t * data       = reinterpret_cast<const int8_t *>(gguf_get_arr_data(ctx,   id));
+                const int8_t * data_other = reinterpret_cast<const int8_t *>(gguf_get_arr_data(other, idx_other));
+                for (int arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    if (bool(data[arr_i]) != bool(data_other[arr_i])) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
             if (type_arr == GGUF_TYPE_STRING) {
                 for (int arr_i = 0; arr_i < arr_n; ++arr_i) {
                     const std::string str       = gguf_get_arr_str(ctx,   id,       arr_i);
@@ -939,8 +972,8 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
                 continue;
             }
 
-            const char * data       = reinterpret_cast<const char *>(gguf_get_arr_data(ctx,   id));
-            const char * data_other = reinterpret_cast<const char *>(gguf_get_arr_data(other, idx_other));
+            const int8_t * data       = reinterpret_cast<const int8_t *>(gguf_get_arr_data(ctx,   id));
+            const int8_t * data_other = reinterpret_cast<const int8_t *>(gguf_get_arr_data(other, idx_other));
             if (!std::equal(data, data + arr_n*gguf_type_size(type_arr), data_other)) {
                 ok = false;
             }
@@ -1028,21 +1061,6 @@ static bool same_tensor_data(const struct ggml_context * orig, const struct ggml
 }
 
 static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta) {
-    FILE * file = tmpfile();
-#ifdef _WIN32
-    if (!file) {
-        printf("%s: failed to create tmpfile(), needs elevated privileges on Windows");
-        printf("%s: skipping tests");
-        return std::make_pair(0, 0);
-    }
-#else
-    GGML_ASSERT(file);
-#endif // _WIN32
-
-    if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
-        return std::make_pair(0, 0); // FIXME
-    }
-
     ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
     printf("%s: device=%s, backend=%s, only_meta=%s\n",
         __func__, ggml_backend_dev_description(dev), ggml_backend_name(backend), only_meta ? "yes" : "no");
@@ -1060,10 +1078,24 @@ static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned
         bbuf       = result.buffer;
     }
 
-    struct gguf_buf gbuf = gguf_buf_init(16 * 1024);
-    gguf_write_to_buf(gguf_ctx_0, &gbuf, only_meta);
-    helper_write(gbuf.data, gbuf.offset, file);
-    rewind(file);
+    FILE * file = tmpfile();
+
+#ifdef _WIN32
+    if (!file) {
+        printf("%s: failed to create tmpfile(), needs elevated privileges on Windows\n", __func__);
+        printf("%s: skipping tests\n", __func__);
+        return std::make_pair(0, 0);
+    }
+#else
+    GGML_ASSERT(file);
+#endif // _WIN32
+
+    {
+        std::vector<int8_t> buf;
+        gguf_write_to_buf(gguf_ctx_0, buf, only_meta);
+        GGML_ASSERT(fwrite(buf.data(), 1, buf.size(), file) == buf.size());
+        rewind(file);
+    }
 
     struct ggml_context * ctx_1 = nullptr;
     struct gguf_init_params gguf_params = {
@@ -1151,9 +1183,8 @@ static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned
     ggml_free(ctx_1);
     gguf_free(gguf_ctx_0);
     gguf_free(gguf_ctx_1);
-    gguf_buf_free(gbuf);
     ggml_backend_free(backend);
-    GGML_ASSERT(fclose(file) == 0);
+    fclose(file);
 
     printf("\n");
     return std::make_pair(npass, ntest);
diff --git a/tests/test-model-load-cancel.cpp b/tests/test-model-load-cancel.cpp
index 858535c3c..9095826fa 100644
--- a/tests/test-model-load-cancel.cpp
+++ b/tests/test-model-load-cancel.cpp
@@ -21,7 +21,7 @@ int main(int argc, char *argv[] ) {
         (void) ctx;
         return progress > 0.50;
     };
-    auto * model = llama_load_model_from_file(model_path, params);
+    auto * model = llama_model_load_from_file(model_path, params);
     llama_backend_free();
     return model == nullptr ? EXIT_SUCCESS : EXIT_FAILURE;
 }
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
index 0af85f002..121c2c60c 100644
--- a/tests/test-tokenizer-0.cpp
+++ b/tests/test-tokenizer-0.cpp
@@ -152,7 +152,7 @@ int main(int argc, char **argv) {
 
         mparams.vocab_only = true;
 
-        model = llama_load_model_from_file(fname.c_str(), mparams);
+        model = llama_model_load_from_file(fname.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@@ -165,7 +165,7 @@ int main(int argc, char **argv) {
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
-            llama_free_model(model);
+            llama_model_free(model);
             return 1;
         }
     }
@@ -300,7 +300,7 @@ int main(int argc, char **argv) {
         fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
     }
 
-    llama_free_model(model);
+    llama_model_free(model);
     llama_free(ctx);
 
     llama_backend_free();
diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp
index 0ff7fc833..5718fab04 100644
--- a/tests/test-tokenizer-1-bpe.cpp
+++ b/tests/test-tokenizer-1-bpe.cpp
@@ -46,7 +46,7 @@ int main(int argc, char **argv) {
 
         mparams.vocab_only = true;
 
-        model = llama_load_model_from_file(fname.c_str(), mparams);
+        model = llama_model_load_from_file(fname.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@@ -59,7 +59,7 @@ int main(int argc, char **argv) {
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
-            llama_free_model(model);
+            llama_model_free(model);
             return 1;
         }
     }
@@ -143,7 +143,7 @@ int main(int argc, char **argv) {
         }
     }
 
-    llama_free_model(model);
+    llama_model_free(model);
     llama_free(ctx);
 
     llama_backend_free();
diff --git a/tests/test-tokenizer-1-spm.cpp b/tests/test-tokenizer-1-spm.cpp
index 9b0716a43..ac05387c9 100644
--- a/tests/test-tokenizer-1-spm.cpp
+++ b/tests/test-tokenizer-1-spm.cpp
@@ -34,7 +34,7 @@ int main(int argc, char ** argv) {
 
         mparams.vocab_only = true;
 
-        model = llama_load_model_from_file(fname.c_str(), mparams);
+        model = llama_model_load_from_file(fname.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@@ -47,7 +47,7 @@ int main(int argc, char ** argv) {
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
-            llama_free_model(model);
+            llama_model_free(model);
             return 1;
         }
     }
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    llama_free_model(model);
+    llama_model_free(model);
     llama_free(ctx);
 
     llama_backend_free();