diff --git a/common/common.cpp b/common/common.cpp index 1e5fc30dd..dbb724fbb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -685,7 +685,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa if (arg == "--lora") { CHECK_ARG params.lora_adapter.emplace_back(argv[i], 1.0f); - params.use_mmap = false; return true; } if (arg == "--lora-scaled") { @@ -693,7 +692,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa const char* lora_adapter = argv[i]; CHECK_ARG params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); - params.use_mmap = false; return true; } if (arg == "--lora-base") { @@ -797,6 +795,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cont_batching = true; return true; } + if (arg == "-nocb" || arg == "--no-cont-batching") { + params.cont_batching = false; + return true; + } if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; return true; @@ -1538,6 +1540,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-np, --parallel N", "number of parallel sequences to decode (default: %d)", params.n_parallel }); options.push_back({ "*", "-ns, --sequences N", "number of sequences to decode (default: %d)", params.n_sequences }); options.push_back({ "*", "-cb, --cont-batching", "enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled" }); + options.push_back({ "*", "-nocb, --no-cont-batching", "disable continuous batching" }); options.push_back({ "multi-modality" }); options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" }); @@ -2084,19 +2087,14 @@ std::tuple llama_init_from_gpt_par for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); float lora_scale = std::get<1>(params.lora_adapter[i]); - int err = llama_model_apply_lora_from_file(model, - lora_adapter.c_str(), - lora_scale, - ((i > 0) || params.lora_base.empty()) - ? NULL - : params.lora_base.c_str(), - params.n_threads); - if (err != 0) { + auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str()); + if (adapter == nullptr) { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); llama_free(lctx); llama_free_model(model); return std::make_tuple(nullptr, nullptr); } + llama_lora_adapter_set(lctx, adapter, lora_scale); } if (params.ignore_eos) { diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bd208f405..852095f88 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2271,13 +2271,6 @@ class InternLM2Model(Model): special_vocab.add_to_gguf(self.gguf_writer) - def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int): - if n_head_kv is not None and n_head != n_head_kv: - n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) - def set_gguf_parameters(self): self.gguf_writer.add_name("InternLM2") self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) @@ -2297,26 +2290,22 @@ class InternLM2Model(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: num_heads = self.hparams["num_attention_heads"] num_kv_heads = self.hparams["num_key_value_heads"] - hidden_size = self.hparams["hidden_size"] + n_embd = self.hparams["hidden_size"] q_per_kv = num_heads // num_kv_heads - head_dim = hidden_size // num_heads + head_dim = n_embd // num_heads num_groups = num_heads // q_per_kv - qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv" - - if re.match(qkv_pattern, name): - bid = re.findall(qkv_pattern, name)[0] + if bid is not None and f"model.layers.{bid}.attention.wqkv" in name: qkv = data_torch - # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim) - qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim)) - q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :] + + qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd)) + q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1] + # The model weights of q and k equire additional reshape. - # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads) - q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads) - # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads) - k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads) - # v = rearrange(v, " o g n i -> o (g n i)").T - v = v.reshape((v.shape[0], -1)).T + q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads) + k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads) + v = v.reshape((-1, v.shape[-1])) + return [ (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), @@ -3620,6 +3609,7 @@ def main() -> None: small_first_shard=args.no_tensor_first_split) logger.info("Set model parameters") + model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL) model_instance.set_gguf_parameters() logger.info("Set model tokenizer") diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py new file mode 100755 index 000000000..4bb939d45 --- /dev/null +++ b/convert_lora_to_gguf.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from dataclasses import dataclass +import logging +import argparse +import os +import sys +import json +from math import prod +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast + +import torch + +if TYPE_CHECKING: + from torch import Tensor + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +# reuse model definitions from convert_hf_to_gguf.py +from convert_hf_to_gguf import LazyTorchTensor, Model + +logger = logging.getLogger("lora-to-gguf") + + +@dataclass +class PartialLoraTensor: + A: Tensor | None = None + B: Tensor | None = None + + +# magic to support tensor shape modifications and splitting +class LoraTorchTensor: + _lora_A: Tensor # (n_rank, row_size) + _lora_B: Tensor # (col_size, n_rank) + _rank: int + + def __init__(self, A: Tensor, B: Tensor): + assert len(A.shape) == len(B.shape) + assert A.shape[-2] == B.shape[-1] + if A.dtype != B.dtype: + A = A.to(torch.float32) + B = B.to(torch.float32) + self._lora_A = A + self._lora_B = B + self._rank = B.shape[-1] + + def get_lora_A_B(self) -> tuple[Tensor, Tensor]: + return (self._lora_A, self._lora_B) + + def __getitem__( + self, + indices: ( + SupportsIndex + | slice + | tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature + ), + ) -> LoraTorchTensor: + shape = self.shape + if isinstance(indices, SupportsIndex): + if len(shape) > 2: + return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) + else: + raise NotImplementedError # can't return a vector + elif isinstance(indices, slice): + if len(shape) > 2: + return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) + else: + return LoraTorchTensor(self._lora_A, self._lora_B[indices]) + elif isinstance(indices, tuple): + assert len(indices) > 0 + if indices[-1] is Ellipsis: + return self[indices[:-1]] + # expand ellipsis + indices = tuple( + u + for v in ( + ( + (slice(None, None) for _ in range(len(indices) - 1)) + if i is Ellipsis + else (i,) + ) + for i in indices + ) + for u in v + ) + + if len(indices) < len(shape): + indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape)))) + + # TODO: make sure this is correct + indices_A = ( + *( + ( + j.__index__() % self._lora_A.shape[i] + if isinstance(j, SupportsIndex) + else slice(None, None) + ) + for i, j in enumerate(indices[:-2]) + ), + slice(None, None), + indices[-1], + ) + indices_B = indices[:-1] + return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B]) + else: + raise NotImplementedError # unknown indice type + + @property + def dtype(self) -> torch.dtype: + assert self._lora_A.dtype == self._lora_B.dtype + return self._lora_A.dtype + + @property + def shape(self) -> tuple[int, ...]: + assert len(self._lora_A.shape) == len(self._lora_B.shape) + return (*self._lora_B.shape[:-1], self._lora_A.shape[-1]) + + def size(self, dim=None): + assert dim is None + return self.shape + + def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor: + if isinstance(shape[0], tuple): + new_shape: tuple[int, ...] = shape[0] + else: + new_shape = cast(tuple[int, ...], shape) + orig_shape = self.shape + if len(new_shape) < 2: + raise NotImplementedError # can't become a vector + + # expand -1 in the shape + if any(dim == -1 for dim in new_shape): + n_elems = prod(orig_shape) + n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape) + assert n_elems % n_new_elems == 0 + new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),) + + if new_shape[-1] != orig_shape[-1]: + raise NotImplementedError # can't reshape the row size trivially + + shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1]) + shape_B = (*new_shape[:-1], self._rank) + return LoraTorchTensor( + self._lora_A.reshape(shape_A), + self._lora_B.reshape(shape_B), + ) + + def reshape_as(self, other: Tensor) -> LoraTorchTensor: + return self.reshape(*other.shape) + + def view(self, *size: int) -> LoraTorchTensor: + return self.reshape(*size) + + def permute(self, *dims: int) -> LoraTorchTensor: + shape = self.shape + dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims) + if dims[-1] == -1: + # TODO: support higher dimensional A shapes bigger than 1 + assert all(dim == 1 for dim in self._lora_A.shape[:-2]) + return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims)) + if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1: + return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims)) + else: + # TODO: compose the above two + raise NotImplementedError + + def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor: + shape = self.shape + dims = [i for i in range(len(shape))] + dims[dim0], dims[dim1] = dims[dim1], dims[dim0] + return self.permute(*dims) + + def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: + return self.transpose(axis0, axis1) + + def to(self, *args, **kwargs): + return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) + + @classmethod + def __torch_function__(cls, func: Callable, types, args=(), kwargs=None): + del types # unused + + if kwargs is None: + kwargs = {} + + if func is torch.permute: + return type(args[0]).permute(*args, **kwargs) + elif func is torch.reshape: + return type(args[0]).reshape(*args, **kwargs) + elif func is torch.stack: + assert isinstance(args[0], Sequence) + dim = kwargs.get("dim", 0) + assert dim == 0 + return LoraTorchTensor( + torch.stack([a._lora_A for a in args[0]], dim), + torch.stack([b._lora_B for b in args[0]], dim), + ) + elif func is torch.cat: + assert isinstance(args[0], Sequence) + dim = kwargs.get("dim", 0) + assert dim == 0 + if len(args[0][0].shape) > 2: + return LoraTorchTensor( + torch.cat([a._lora_A for a in args[0]], dim), + torch.cat([b._lora_B for b in args[0]], dim), + ) + elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]): + return LoraTorchTensor( + args[0][0]._lora_A, + torch.cat([b._lora_B for b in args[0]], dim), + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + +def get_base_tensor_name(lora_tensor_name: str) -> str: + base_name = lora_tensor_name.replace("base_model.model.", "") + base_name = base_name.replace(".lora_A.weight", ".weight") + base_name = base_name.replace(".lora_B.weight", ".weight") + return base_name + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file") + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "--no-lazy", action="store_true", + help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + parser.add_argument( + "--base", type=Path, required=True, + help="directory containing base model file", + ) + parser.add_argument( + "lora_path", type=Path, + help="directory containing LoRA adapter file", + ) + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "auto": gguf.LlamaFileType.GUESSED, + } + + ftype = ftype_map[args.outtype] + + dir_base_model: Path = args.base + dir_lora: Path = args.lora_path + lora_config = dir_lora / "adapter_config.json" + input_model = dir_lora / "adapter_model.safetensors" + + if args.outfile is not None: + fname_out = args.outfile + else: + # output in the same directory as the model by default + fname_out = dir_lora / 'ggml-lora-{ftype}.gguf' + + if os.path.exists(input_model): + # lazy import load_file only if lora is in safetensors format. + from safetensors.torch import load_file + + lora_model = load_file(input_model, device="cpu") + else: + input_model = os.path.join(dir_lora, "adapter_model.bin") + lora_model = torch.load(input_model, map_location="cpu", weights_only=True) + + # load base model + logger.info(f"Loading base model: {dir_base_model.name}") + hparams = Model.load_hparams(dir_base_model) + with torch.inference_mode(): + try: + model_class = Model.from_model_architecture(hparams["architectures"][0]) + except NotImplementedError: + logger.error(f"Model {hparams['architectures'][0]} is not supported") + sys.exit(1) + + class LoraModel(model_class): + model_arch = model_class.model_arch + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + tensor_map: dict[str, PartialLoraTensor] = {} + + for name, tensor in lora_model.items(): + if self.lazy: + tensor = LazyTorchTensor.from_eager(tensor) + base_name = get_base_tensor_name(name) + is_lora_a = ".lora_A.weight" in name + is_lora_b = ".lora_B.weight" in name + if not is_lora_a and not is_lora_b: + if ".base_layer.weight" in name: + continue + logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") + sys.exit(1) + + if base_name in tensor_map: + if is_lora_a: + tensor_map[base_name].A = tensor + else: + tensor_map[base_name].B = tensor + else: + if is_lora_a: + tensor_map[base_name] = PartialLoraTensor(A=tensor) + else: + tensor_map[base_name] = PartialLoraTensor(B=tensor) + + for name, tensor in tensor_map.items(): + assert tensor.A is not None + assert tensor.B is not None + yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B))) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + dest = super().modify_tensors(data_torch, name, bid) + for dest_name, dest_data in dest: + assert isinstance(dest_data, LoraTorchTensor) + lora_a, lora_b = dest_data.get_lora_A_B() + + yield (dest_name + ".lora_a", lora_a) + yield (dest_name + ".lora_b", lora_b) + + model_instance = LoraModel( + dir_base_model, + ftype, + fname_out, + is_big_endian=args.bigendian, + use_temp_file=False, + eager=args.no_lazy, + model_name=None, + ) + + with open(lora_config, "r") as f: + lparams: dict[str, Any] = json.load(f) + + alpha = lparams["lora_alpha"] + + model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER) + model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") + model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha)) + model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) + logger.info("Exporting model...") + model_instance.write() + logger.info(f"Model successfully exported to {model_instance.fname_out}") diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 2712b66c1..04c5ccbbe 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -9,15 +9,15 @@ Adding a model requires few steps: After following these steps, you can open PR. Also, it is important to check that the examples and main ggml backends (CUDA, METAL, CPU) are working with the new architecture, especially: -- [main](../examples/main) -- [imatrix](../examples/imatrix) -- [quantize](../examples/quantize) -- [server](../examples/server) +- [main](/examples/main/) +- [imatrix](/examples/imatrix/) +- [quantize](/examples/quantize/) +- [server](/examples/server/) ### 1. Convert the model to GGUF This step is done in python with a `convert` script using the [gguf](https://pypi.org/project/gguf/) library. -Depending on the model architecture, you can use either [convert_hf_to_gguf.py](../convert_hf_to_gguf.py) or [examples/convert_legacy_llama.py](../examples/convert_legacy_llama.py) (for `llama/llama2` models in `.pth` format). +Depending on the model architecture, you can use either [convert_hf_to_gguf.py](/convert_hf_to_gguf.py) or [examples/convert_legacy_llama.py](/examples/convert_legacy_llama.py) (for `llama/llama2` models in `.pth` format). The convert script reads the model configuration, tokenizer, tensor names+data and converts them to GGUF metadata and tensors. @@ -31,7 +31,7 @@ class MyModel(Model): model_arch = gguf.MODEL_ARCH.GROK ``` -2. Define the layout of the GGUF tensors in [constants.py](../gguf-py/gguf/constants.py) +2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py) Add an enum entry in `MODEL_ARCH`, the model human friendly name in `MODEL_ARCH_NAMES` and the GGUF tensor names in `MODEL_TENSORS`. @@ -54,7 +54,7 @@ Example for `falcon` model: As a general rule, before adding a new tensor name to GGUF, be sure the equivalent naming does not already exist. -Once you have found the GGUF tensor name equivalent, add it to the [tensor_mapping.py](../gguf-py/gguf/tensor_mapping.py) file. +Once you have found the GGUF tensor name equivalent, add it to the [tensor_mapping.py](/gguf-py/gguf/tensor_mapping.py) file. If the tensor name is part of a repetitive layer/block, the key word `bid` substitutes it. @@ -100,7 +100,7 @@ Have a look at existing implementation like `build_llama`, `build_dbrx` or `buil When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR. -Note: to debug the inference graph: you can use [llama-eval-callback](../examples/eval-callback). +Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/). ## GGUF specification diff --git a/docs/development/token_generation_performance_tips.md b/docs/development/token_generation_performance_tips.md index c0840cad5..41b7232c9 100644 --- a/docs/development/token_generation_performance_tips.md +++ b/docs/development/token_generation_performance_tips.md @@ -1,7 +1,7 @@ # Token generation performance troubleshooting ## Verifying that the model is running on the GPU with CUDA -Make sure you compiled llama with the correct env variables according to [this guide](../README.md#CUDA), so that llama accepts the `-ngl N` (or `--n-gpu-layers N`) flag. When running llama, you may configure `N` to be very large, and llama will offload the maximum possible number of layers to the GPU, even if it's less than the number you configured. For example: +Make sure you compiled llama with the correct env variables according to [this guide](/docs/build.md#cuda), so that llama accepts the `-ngl N` (or `--n-gpu-layers N`) flag. When running llama, you may configure `N` to be very large, and llama will offload the maximum possible number of layers to the GPU, even if it's less than the number you configured. For example: ```shell ./llama-cli -m "path/to/model.gguf" -ngl 200000 -p "Please sir, may I have some " ``` diff --git a/examples/pydantic_models_to_grammar.py b/examples/pydantic_models_to_grammar.py index d8145710c..93e5dcb6c 100644 --- a/examples/pydantic_models_to_grammar.py +++ b/examples/pydantic_models_to_grammar.py @@ -6,7 +6,7 @@ import re from copy import copy from enum import Enum from inspect import getdoc, isclass -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints from docstring_parser import parse from pydantic import BaseModel, create_model @@ -53,35 +53,38 @@ class PydanticDataType(Enum): def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: - if isclass(pydantic_type) and issubclass(pydantic_type, str): + origin_type = get_origin(pydantic_type) + origin_type = pydantic_type if origin_type is None else origin_type + + if isclass(origin_type) and issubclass(origin_type, str): return PydanticDataType.STRING.value - elif isclass(pydantic_type) and issubclass(pydantic_type, bool): + elif isclass(origin_type) and issubclass(origin_type, bool): return PydanticDataType.BOOLEAN.value - elif isclass(pydantic_type) and issubclass(pydantic_type, int): + elif isclass(origin_type) and issubclass(origin_type, int): return PydanticDataType.INTEGER.value - elif isclass(pydantic_type) and issubclass(pydantic_type, float): + elif isclass(origin_type) and issubclass(origin_type, float): return PydanticDataType.FLOAT.value - elif isclass(pydantic_type) and issubclass(pydantic_type, Enum): + elif isclass(origin_type) and issubclass(origin_type, Enum): return PydanticDataType.ENUM.value - elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel): - return format_model_and_field_name(pydantic_type.__name__) - elif get_origin(pydantic_type) is list: + elif isclass(origin_type) and issubclass(origin_type, BaseModel): + return format_model_and_field_name(origin_type.__name__) + elif origin_type is list: element_type = get_args(pydantic_type)[0] return f"{map_pydantic_type_to_gbnf(element_type)}-list" - elif get_origin(pydantic_type) is set: + elif origin_type is set: element_type = get_args(pydantic_type)[0] return f"{map_pydantic_type_to_gbnf(element_type)}-set" - elif get_origin(pydantic_type) is Union: + elif origin_type is Union: union_types = get_args(pydantic_type) union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types] return f"union-{'-or-'.join(union_rules)}" - elif get_origin(pydantic_type) is Optional: + elif origin_type is Optional: element_type = get_args(pydantic_type)[0] return f"optional-{map_pydantic_type_to_gbnf(element_type)}" - elif isclass(pydantic_type): - return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}" - elif get_origin(pydantic_type) is dict: + elif isclass(origin_type): + return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}" + elif origin_type is dict: key_type, value_type = get_args(pydantic_type) return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}" else: @@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name): # Modify this comprehension members = [ f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}' - for name, param_type in cls.__annotations__.items() + for name, param_type in get_type_hints(cls).items() if name != "self" ] @@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type( field_name = format_model_and_field_name(field_name) gbnf_type = map_pydantic_type_to_gbnf(field_type) - if isclass(field_type) and issubclass(field_type, BaseModel): + origin_type = get_origin(field_type) + origin_type = field_type if origin_type is None else origin_type + + if isclass(origin_type) and issubclass(origin_type, BaseModel): nested_model_name = format_model_and_field_name(field_type.__name__) nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules) rules.extend(nested_model_rules) gbnf_type, rules = nested_model_name, rules - elif isclass(field_type) and issubclass(field_type, Enum): + elif isclass(origin_type) and issubclass(origin_type, Enum): enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}" rules.append(enum_rule) gbnf_type, rules = model_name + "-" + field_name, rules - elif get_origin(field_type) == list: # Array + elif origin_type is list: # Array element_type = get_args(field_type)[0] element_rule_name, additional_rules = generate_gbnf_rule_for_type( model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules @@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type( rules.append(array_rule) gbnf_type, rules = model_name + "-" + field_name, rules - elif get_origin(field_type) == set or field_type == set: # Array + elif origin_type is set: # Array element_type = get_args(field_type)[0] element_rule_name, additional_rules = generate_gbnf_rule_for_type( model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules @@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type( gbnf_type = f"{model_name}-{field_name}-optional" else: gbnf_type = f"{model_name}-{field_name}-union" - elif isclass(field_type) and issubclass(field_type, str): + elif isclass(origin_type) and issubclass(origin_type, str): if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None: triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False) markdown_string = field_info.json_schema_extra.get("markdown_code_block", False) @@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type( gbnf_type = PydanticDataType.STRING.value elif ( - isclass(field_type) - and issubclass(field_type, float) + isclass(origin_type) + and issubclass(origin_type, float) and field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None @@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type( ) elif ( - isclass(field_type) - and issubclass(field_type, int) + isclass(origin_type) + and issubclass(origin_type, int) and field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None @@ -462,7 +468,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas if not issubclass(model, BaseModel): # For non-Pydantic classes, generate model_fields from __annotations__ or __init__ if hasattr(model, "__annotations__") and model.__annotations__: - model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues] + model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()} else: init_signature = inspect.signature(model.__init__) parameters = init_signature.parameters @@ -470,7 +476,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas name != "self"} else: # For Pydantic models, use model_fields and check for ellipsis (required fields) - model_fields = model.__annotations__ + model_fields = get_type_hints(model) model_rule_parts = [] nested_rules = [] @@ -706,7 +712,7 @@ def generate_markdown_documentation( else: documentation += f" Fields:\n" # noqa: F541 if isclass(model) and issubclass(model, BaseModel): - for name, field_type in model.__annotations__.items(): + for name, field_type in get_type_hints(model).items(): # if name == "markdown_code_block": # continue if get_origin(field_type) == list: @@ -754,14 +760,17 @@ def generate_field_markdown( field_info = model.model_fields.get(field_name) field_description = field_info.description if field_info and field_info.description else "" - if get_origin(field_type) == list: + origin_type = get_origin(field_type) + origin_type = field_type if origin_type is None else origin_type + + if origin_type == list: element_type = get_args(field_type)[0] field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})" if field_description != "": field_text += ":\n" else: field_text += "\n" - elif get_origin(field_type) == Union: + elif origin_type == Union: element_types = get_args(field_type) types = [] for element_type in element_types: @@ -792,9 +801,9 @@ def generate_field_markdown( example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example field_text += f"{indent} Example: {example_text}\n" - if isclass(field_type) and issubclass(field_type, BaseModel): + if isclass(origin_type) and issubclass(origin_type, BaseModel): field_text += f"{indent} Details:\n" - for name, type_ in field_type.__annotations__.items(): + for name, type_ in get_type_hints(field_type).items(): field_text += generate_field_markdown(name, type_, field_type, depth + 2) return field_text @@ -855,7 +864,7 @@ def generate_text_documentation( if isclass(model) and issubclass(model, BaseModel): documentation_fields = "" - for name, field_type in model.__annotations__.items(): + for name, field_type in get_type_hints(model).items(): # if name == "markdown_code_block": # continue if get_origin(field_type) == list: @@ -948,7 +957,7 @@ def generate_field_text( if isclass(field_type) and issubclass(field_type, BaseModel): field_text += f"{indent} Details:\n" - for name, type_ in field_type.__annotations__.items(): + for name, type_ in get_type_hints(field_type).items(): field_text += generate_field_text(name, type_, field_type, depth + 2) return field_text diff --git a/examples/pydantic_models_to_grammar_examples.py b/examples/pydantic_models_to_grammar_examples.py index 8e7f46cf9..504ed98df 100644 --- a/examples/pydantic_models_to_grammar_examples.py +++ b/examples/pydantic_models_to_grammar_examples.py @@ -20,6 +20,8 @@ def create_completion(prompt, grammar): response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data) data = response.json() + assert data.get("error") is None, data + print(data["content"]) return data["content"] diff --git a/examples/server/README.md b/examples/server/README.md index cb45ee06d..e477d1501 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -15,69 +15,281 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggerganov/llama.cpp/issues/4216). -**Command line options:** +## Usage -- `-v`, `--verbose`: Enable verbose server output. When using the `/completion` endpoint, this includes the tokenized prompt, the full request and the full response. -- `-t N`, `--threads N`: Set the number of threads to use by CPU layers during generation. Not used by model layers that are offloaded to GPU. This option has no effect when using the maximum number of GPU layers. Default: `std::thread::hardware_concurrency()` (number of CPU cores). -- `-tb N, --threads-batch N`: Set the number of threads to use by CPU layers during batch and prompt processing (>= 32 tokens). This option has no effect if a GPU is available. Default: `--threads`. -- `--threads-http N`: Number of threads in the http server pool to process requests. Default: `max(std::thread::hardware_concurrency() - 1, --parallel N + 2)` -- `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`). -- `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file. Default: unused -- `-hfr REPO, --hf-repo REPO`: Hugging Face model repository. Default: unused -- `-hff FILE, --hf-file FILE`: Hugging Face model file. Default: unused -- `-a ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses. -- `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is `512`, but LLaMA models were built with a context of `2048`, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of `4096`. -- `-ngl N`, `--n-gpu-layers N`: When compiled with GPU support, this option allows offloading some layers to the GPU for computation. Generally results in increased performance. -- `-mg i, --main-gpu i`: When using multiple GPUs, this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default, GPU `0` is used. -- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs, this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default, the data is split in proportion to VRAM, but this may not be optimal for performance. -- `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `2048` -- `-ub N`, `--ubatch-size N`: Physical maximum batch size. Default: `512` -- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped. -- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed. -- `--numa STRATEGY`: Attempt one of the below optimization strategies that may help on some NUMA systems -- `--numa distribute`: Spread execution evenly over all nodes -- `--numa isolate`: Only spawn threads on CPUs on the node that execution started on -- `--numa numactl`: Use the CPU map provided by numactl. If run without this previously, it is recommended to drop the system page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/1437 -- `--numa`: Attempt optimizations that may help on some NUMA systems. -- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains. -- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation. -- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600` -- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1` -- `--port`: Set the port to listen. Default: `8080` -- `--path`: Path from which to serve static files. Default: disabled -- `--api-key`: Set an api key for request authorization. By default, the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys. -- `--api-key-file`: Path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`s. -- `--embeddings`: Enable embedding vector output and the OAI compatible endpoint /v1/embeddings. Physical batch size (`--ubatch-size`) must be carefully defined. Default: disabled -- `-np N`, `--parallel N`: Set the number of slots for process requests. Default: `1`. Values > 1 will allow for higher throughput with multiple parallel requests but the results will **not** be deterministic due to differences in rounding error. -- `-cb`, `--cont-batching`: Enable continuous batching (a.k.a dynamic batching). Default: disabled -- `-spf FNAME`, `--system-prompt-file FNAME` Set a file to load a system prompt (initial prompt of all slots). This is useful for chat applications. [See more](#change-system-prompt-on-runtime) -- `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA. -- `--grp-attn-n`: Set the group attention factor to extend context size through self-extend. Used together with group attention width `--grp-attn-w`. Default: `1`, which is disabled. -- `--grp-attn-w`: Set the group attention width to extend context size through self-extend. Used together with group attention factor `--grp-attn-n`. Default: `512` -- `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1` -- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included. -- `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled -- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled. -- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) -- `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled -- `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json` -- `--rope-scaling` : RoPE scaling method. Defaults to linear unless otherwise specified by the model. Options are `none`, `linear`, `yarn` -- `--rope-freq-base N` : RoPE frequency base (default: loaded from model) -- `--rope-freq-scale N`: RoPE frequency scaling factor, expands context by a factor of 1/N (e.g. 0.25) -- `--yarn-ext-factor N` : YaRN: extrapolation mix factor (Default: 1.0, 0.0 = full interpolation) -- `--yarn-attn-factor N` : YaRN: scale sqrt(t) or attention magnitude (default: 1.0) -- `--yarn-beta-slow N`: YaRN: High correction dim or alpha (default: 1.0) -- `--yarn-beta-fast N`: YaRN: low correction dim or beta (default: 32.0) -- `--pooling` : Pooling type for embeddings, use model default if unspecified. Options are `none`, `mean`, `cls` -- `-dt N`, `--defrag-thold N`: KV cache defragmentation threshold (default: -1.0, < 0 = disabled) -- `-fa`, `--flash-attn` : enable flash attention (default: disabled). -- `-ctk TYPE`, `--cache-type-k TYPE` : KV cache data type for K (default: `f16`, options `f32`, `f16`, `q8_0`, `q4_0`, `q4_1`, `iq4_nl`, `q5_0`, or `q5_1`) -- `-ctv TYPE`, `--cache-type-v TYPE` : KV cache type for V (default `f16`, see `-ctk` for options) -- `--spm-infill` : Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. +``` +usage: ./llama-server [options] + +general: + + -h, --help, --usage print usage and exit + --version show version and build info + -v, --verbose print verbose information + --verbosity N set specific verbosity level (default: 0) + --verbose-prompt print a verbose prompt before generation (default: false) + --no-display-prompt don't print prompt at generation (default: false) + -co, --color colorise output to distinguish prompt and user input from generations (default: false) + -s, --seed SEED RNG seed (default: -1, use random seed for < 0) + -t, --threads N number of threads to use during generation (default: 8) + -tb, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads) + -td, --threads-draft N number of threads to use during generation (default: same as --threads) + -tbd, --threads-batch-draft N number of threads to use during batch and prompt processing (default: same as --threads-draft) + --draft N number of tokens to draft for speculative decoding (default: 5) + -ps, --p-split N speculative decoding split probability (default: 0.1) + -lcs, --lookup-cache-static FNAME + path to static lookup cache to use for lookup decoding (not updated by generation) + -lcd, --lookup-cache-dynamic FNAME + path to dynamic lookup cache to use for lookup decoding (updated by generation) + -c, --ctx-size N size of the prompt context (default: 0, 0 = loaded from model) + -n, --predict N number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled) + -b, --batch-size N logical maximum batch size (default: 2048) + -ub, --ubatch-size N physical maximum batch size (default: 512) + --keep N number of tokens to keep from the initial prompt (default: 0, -1 = all) + --chunks N max number of chunks to process (default: -1, -1 = all) + -fa, --flash-attn enable Flash Attention (default: disabled) + -p, --prompt PROMPT prompt to start generation with + in conversation mode, this will be used as system prompt + (default: '') + -f, --file FNAME a file containing the prompt (default: none) + --in-file FNAME an input file (repeat to specify multiple files) + -bf, --binary-file FNAME binary file containing the prompt (default: none) + -e, --escape process escapes sequences (\n, \r, \t, \', \", \\) (default: true) + --no-escape do not process escape sequences + -ptc, --print-token-count N print token count every N tokens (default: -1) + --prompt-cache FNAME file to cache prompt state for faster startup (default: none) + --prompt-cache-all if specified, saves user input and generations to cache as well + not supported with --interactive or other interactive options + --prompt-cache-ro if specified, uses the prompt cache but does not update it + -r, --reverse-prompt PROMPT halt generation at PROMPT, return control in interactive mode + can be specified more than once for multiple prompts + -sp, --special special tokens output enabled (default: false) + -cnv, --conversation run in conversation mode, does not print special tokens and suffix/prefix + if suffix/prefix are not specified, default chat template will be used + (default: false) + -i, --interactive run in interactive mode (default: false) + -if, --interactive-first run in interactive mode and wait for input right away (default: false) + -mli, --multiline-input allows you to write or paste multiple lines without ending each in '\' + --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string + --in-prefix STRING string to prefix user inputs with (default: empty) + --in-suffix STRING string to suffix after user inputs with (default: empty) + --spm-infill use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) + +sampling: + + --samplers SAMPLERS samplers that will be used for generation in the order, separated by ';' + (default: top_k;tfs_z;typical_p;top_p;min_p;temperature) + --sampling-seq SEQUENCE simplified sequence for samplers that will be used (default: kfypmt) + --ignore-eos ignore end of stream token and continue generating (implies --logit-bias EOS-inf) + --penalize-nl penalize newline tokens (default: false) + --temp N temperature (default: 0.8) + --top-k N top-k sampling (default: 40, 0 = disabled) + --top-p N top-p sampling (default: 0.9, 1.0 = disabled) + --min-p N min-p sampling (default: 0.1, 0.0 = disabled) + --tfs N tail free sampling, parameter z (default: 1.0, 1.0 = disabled) + --typical N locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) + --repeat-last-n N last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) + --repeat-penalty N penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) + --presence-penalty N repeat alpha presence penalty (default: 0.0, 0.0 = disabled) + --frequency-penalty N repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) + --dynatemp-range N dynamic temperature range (default: 0.0, 0.0 = disabled) + --dynatemp-exp N dynamic temperature exponent (default: 1.0) + --mirostat N use Mirostat sampling. + Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used. + (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) + --mirostat-lr N Mirostat learning rate, parameter eta (default: 0.1) + --mirostat-ent N Mirostat target entropy, parameter tau (default: 5.0) + -l TOKEN_ID(+/-)BIAS modifies the likelihood of token appearing in the completion, + i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello', + or `--logit-bias 15043-1` to decrease likelihood of token ' Hello' + --cfg-negative-prompt PROMPT + negative prompt to use for guidance (default: '') + --cfg-negative-prompt-file FNAME + negative prompt file to use for guidance + --cfg-scale N strength of guidance (default: 1.0, 1.0 = disable) + --chat-template JINJA_TEMPLATE + set custom jinja chat template (default: template taken from model's metadata) + if suffix/prefix are specified, template will be disabled + only commonly used templates are accepted: + https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template + +grammar: + + --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') + --grammar-file FNAME file to read grammar from + -j, --json-schema SCHEMA JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object + For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead + +embedding: + + --pooling {none,mean,cls,last} + pooling type for embeddings, use model default if unspecified + --attention {causal,non-causal} + attention type for embeddings, use model default if unspecified + +context hacking: + + --rope-scaling {none,linear,yarn} + RoPE frequency scaling method, defaults to linear unless specified by the model + --rope-scale N RoPE context scaling factor, expands context by a factor of N + --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model) + --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N + --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size) + --yarn-ext-factor N YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation) + --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0) + --yarn-beta-slow N YaRN: high correction dim or alpha (default: 1.0) + --yarn-beta-fast N YaRN: low correction dim or beta (default: 32.0) + -gan, --grp-attn-n N group-attention factor (default: 1) + -gaw, --grp-attn-w N group-attention width (default: 512.0) + -dkvc, --dump-kv-cache verbose print of the KV cache + -nkvo, --no-kv-offload disable KV offload + -ctk, --cache-type-k TYPE KV cache data type for K (default: f16) + -ctv, --cache-type-v TYPE KV cache data type for V (default: f16) + +perplexity: + + --all-logits return logits for all tokens in the batch (default: false) + --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f + --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: 400) + --winogrande compute Winogrande score over random tasks from datafile supplied with -f + --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: 0) + --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f + --multiple-choice-tasks N + number of tasks to use when computing the multiple choice score (default: 0) + --kl-divergence computes KL-divergence to logits provided via --kl-divergence-base + --ppl-stride N stride for perplexity calculation (default: 0) + --ppl-output-type {0,1} output type for perplexity calculation (default: 0) + +parallel: + + -dt, --defrag-thold N KV cache defragmentation threshold (default: -1.0, < 0 - disabled) + -np, --parallel N number of parallel sequences to decode (default: 1) + -ns, --sequences N number of sequences to decode (default: 1) + -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled) + +multi-modality: + + --mmproj FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md + --image FILE path to an image file. use with multimodal models. Specify multiple times for batching + +backend: + + --rpc SERVERS comma separated list of RPC servers + --mlock force system to keep model in RAM rather than swapping or compressing + --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock) + --numa TYPE attempt optimizations that help on some NUMA systems + - distribute: spread execution evenly over all nodes + - isolate: only spawn threads on CPUs on the node that execution started on + - numactl: use the CPU map provided by numactl + if run without this previously, it is recommended to drop the system page cache before using this + see https://github.com/ggerganov/llama.cpp/issues/1437 + +model: + + --check-tensors check model tensor data for invalid values (default: false) + --override-kv KEY=TYPE:VALUE + advanced option to override model metadata by key. may be specified multiple times. + types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false + --lora FNAME apply LoRA adapter (implies --no-mmap) + --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap) + --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter + --control-vector FNAME add a control vector + note: this argument can be repeated to add multiple control vectors + --control-vector-scaled FNAME SCALE + add a control vector with user defined scaling SCALE + note: this argument can be repeated to add multiple scaled control vectors + --control-vector-layer-range START END + layer range to apply the control vector(s) to, start and end inclusive + -m, --model FNAME model path (default: models/$filename with filename from --hf-file + or --model-url if set, otherwise models/7B/ggml-model-f16.gguf) + -md, --model-draft FNAME draft model for speculative decoding (default: unused) + -mu, --model-url MODEL_URL model download url (default: unused) + -hfr, --hf-repo REPO Hugging Face model repository (default: unused) + -hff, --hf-file FILE Hugging Face model file (default: unused) + -hft, --hf-token TOKEN Hugging Face access token (default: value from HF_TOKEN environment variable) + +retrieval: + + --context-file FNAME file to load context from (repeat to specify multiple files) + --chunk-size N minimum length of embedded text chunks (default: 64) + --chunk-separator STRING + separator between chunks (default: ' + ') + +passkey: + + --junk N number of times to repeat the junk text (default: 250) + --pos N position of the passkey in the junk text (default: -1) + +imatrix: + + -o, --output FNAME output file (default: 'imatrix.dat') + --output-frequency N output the imatrix every N iterations (default: 10) + --save-frequency N save an imatrix copy every N iterations (default: 0) + --process-output collect data for the output tensor (default: false) + --no-ppl do not compute perplexity (default: true) + --chunk N start processing the input from chunk N (default: 0) + +bench: + + -pps is the prompt shared across parallel sequences (default: false) + -npp n0,n1,... number of prompt tokens + -ntg n0,n1,... number of text generation tokens + -npl n0,n1,... number of parallel prompts + +embedding: + + --embd-normalize normalisation for embendings (default: 2) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) + --embd-output-format empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix + --embd-separator separator of embendings (default \n) for example "<#sep#>" + +server: + + --host HOST ip address to listen (default: 127.0.0.1) + --port PORT port to listen (default: 8080) + --path PATH path to serve static files from (default: ) + --embedding(s) enable embedding endpoint (default: disabled) + --api-key KEY API key to use for authentication (default: none) + --api-key-file FNAME path to file containing API keys (default: none) + --ssl-key-file FNAME path to file a PEM-encoded SSL private key + --ssl-cert-file FNAME path to file a PEM-encoded SSL certificate + --timeout N server read/write timeout in seconds (default: 600) + --threads-http N number of threads used to process HTTP requests (default: -1) + --system-prompt-file FNAME + set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications + --log-format {text,json} + log output format: json or text (default: json) + --metrics enable prometheus compatible metrics endpoint (default: disabled) + --no-slots disables slots monitoring endpoint (default: enabled) + --slot-save-path PATH path to save slot kv cache (default: disabled) + --chat-template JINJA_TEMPLATE + set custom jinja chat template (default: template taken from model's metadata) + only commonly used templates are accepted: + https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template + -sps, --slot-prompt-similarity SIMILARITY + how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled) + + +logging: + + --simple-io use basic IO for better compatibility in subprocesses and limited consoles + -ld, --logdir LOGDIR path under which to save YAML logs (no logging if unset) + --log-test Run simple logging test + --log-disable Disable trace logs + --log-enable Enable trace logs + --log-file FNAME Specify a log filename (without extension) + --log-new Create a separate new log file on start. Each log file will have unique name: "..log" + --log-append Don't truncate the old log file. + +cvector: + + -o, --output FNAME output file (default: 'control_vector.gguf') + --positive-file FNAME positive prompts file, one prompt per line (default: 'examples/cvector-generator/positive.txt') + --negative-file FNAME negative prompts file, one prompt per line (default: 'examples/cvector-generator/negative.txt') + --pca-batch N batch size used for PCA. Larger batch runs faster, but uses more memory (default: 100) + --pca-iter N number of iterations used for PCA (default: 1000) + --method {pca,mean} dimensionality reduction method to be used (default: pca) +``` -**If compiled with `LLAMA_SERVER_SSL=ON`** -- `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key -- `--ssl-cert-file FNAME`: path to file a PEM-encoded SSL certificate ## Build diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 40838cf4f..26535b1c4 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -14,7 +14,9 @@ #include "ggml-aarch64.h" +#if defined(__GNUC__) #pragma GCC diagnostic ignored "-Woverlength-strings" +#endif #define UNUSED GGML_UNUSED diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index ed784ea1c..39e345b66 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1876,7 +1876,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; + && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2 + && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 5a890237f..36518ff93 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -291,29 +291,6 @@ static void sqr_f32(const float * x, float * dst, const int k, dst[i] = x[i] * x[i]; } -static void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02, - const sycl::nd_item<3> &item_ct1) { - int nidx = item_ct1.get_local_id(2) + - item_ct1.get_group(2) * item_ct1.get_local_range(2); - if (nidx >= ne0) { - return; - } - // operation - int offset_dst = nidx + item_ct1.get_group(1) * ne0 + - item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); - if (item_ct1.get_group(0) < ne02) { // src0 - int offset_src = - nidx + item_ct1.get_group(1) * ne0 + - item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - nidx + item_ct1.get_group(1) * ne0 + - (item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1); - dst[offset_dst] = y[offset_src]; - } -} - static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, const float sf1, @@ -1347,20 +1324,6 @@ static void sqr_f32_sycl(const float *x, float *dst, const int k, }); } -static void concat_f32_sycl(const float *x, const float *y, float *dst, - const int ne0, int ne1, int ne2, int ne02, - queue_ptr stream) { - int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; - sycl::range<3> gridDim(ne2, ne1, num_blocks); - stream->parallel_for( - sycl::nd_range<3>(gridDim * - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - concat_f32(x, y, dst, ne0, ne02, item_ct1); - }); -} - static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, const float sf1, @@ -2429,28 +2392,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor (void) src1_dd; } -inline void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { -#pragma message("TODO: generalize concat kernel for dim != 2") -#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563") - int dim = dst->op_params[0]; - GGML_ASSERT(dim == 2); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_sycl(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream); - } - - (void) src1; - (void) dst; -} - inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, @@ -3359,12 +3300,6 @@ static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_ten GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_concat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_concat); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale); @@ -4101,7 +4036,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens func = ggml_sycl_group_norm; break; case GGML_OP_CONCAT: - func = ggml_sycl_concat; + func = ggml_sycl_op_concat; break; case GGML_OP_UPSCALE: func = ggml_sycl_upscale; diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 2a789edfc..067181de3 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -13,6 +13,7 @@ #ifndef GGML_SYCL_BACKEND_HPP #define GGML_SYCL_BACKEND_HPP +#include "concat.hpp" #include "common.hpp" #include "convert.hpp" #include "dequantize.hpp" diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp new file mode 100644 index 000000000..632eedb9d --- /dev/null +++ b/ggml/src/ggml-sycl/concat.cpp @@ -0,0 +1,195 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "concat.hpp" +#include "common.hpp" + +static void concat_f32_dim0(const float *x, const float *y, float *dst, + const int ne0, const int ne00, + const sycl::nd_item<3> &item_ct1) { + int nidx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (nidx >= ne0) { + return; + } + // operation + int offset_dst = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + if (nidx < ne00) { // src0 + int offset_src = nidx + item_ct1.get_group(1) * ne00 + + item_ct1.get_group(0) * ne00 * item_ct1.get_group_range(1); + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx - ne00 + item_ct1.get_group(1) * (ne0 - ne00) + + item_ct1.get_group(0) * (ne0 - ne00) * item_ct1.get_group_range(1); + dst[offset_dst] = y[offset_src]; + } +} + +static void concat_f32_dim1(const float *x, const float *y, float *dst, + const int ne0, const int ne01, + const sycl::nd_item<3> &item_ct1) { + int nidx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (nidx >= ne0) { + return; + } + // operation + int offset_dst = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + if (item_ct1.get_group(1) < ne01) { // src0 + int offset_src = + nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + (item_ct1.get_group(1) - ne01) * ne0 + + item_ct1.get_group(0) * ne0 * (item_ct1.get_group_range(1) - ne01); + dst[offset_dst] = y[offset_src]; + } +} + +static void concat_f32_dim2(const float *x, const float *y, float *dst, + const int ne0, const int ne02, + const sycl::nd_item<3> &item_ct1) { + int nidx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (nidx >= ne0) { + return; + } + // operation + int offset_dst = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + if (item_ct1.get_group(0) < ne02) { // src0 + int offset_src = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + item_ct1.get_group(1) * ne0 + + (item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1); + dst[offset_dst] = y[offset_src]; + } +} + +static void concat_f32_sycl(const float *x, const float *y, float *dst, + int ne00, int ne01, int ne02, int ne0, int ne1, + int ne2, int dim, queue_ptr stream) { + int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; + sycl::range<3> gridDim(ne2, ne1, num_blocks); + switch (dim) { + case 0: + stream->parallel_for( + sycl::nd_range<3>(gridDim * + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); + }); + break; + case 1: + stream->parallel_for( + sycl::nd_range<3>(gridDim * + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); + }); + break; + default: + stream->parallel_for( + sycl::nd_range<3>(gridDim * + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); + }); + break; + } +} + +// non-contiguous kernel (slow) +static void concat_f32_sycl_non_cont( + queue_ptr stream, const char *src0, const char *src1, char *dst, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00, + uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/, + int64_t /*ne11*/, int64_t /*ne12*/, int64_t /*ne13*/, uint64_t nb10, + uint64_t nb11, uint64_t nb12, uint64_t nb13, int64_t ne0, int64_t ne1, + int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2, + uint64_t nb3, int32_t dim) { + sycl::range<3> gridDim(ne3, ne2, ne1); + stream->parallel_for( + sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + int64_t i3 = item_ct1.get_group(0); + int64_t i2 = item_ct1.get_group(1); + int64_t i1 = item_ct1.get_group(2); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + + const float *x; + + for (int i0 = item_ct1.get_local_id(2); i0 < ne0; + i0 += item_ct1.get_local_range(2)) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 + + (i0)*nb00); + } else { + x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + + (i1 - o[1]) * nb11 + (i0 - o[0]) * nb10); + } + + float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); + + *y = *x; + } + }); +} + +void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst) { + queue_ptr stream = ctx.stream(); + + const int32_t dim = ((int32_t *)dst->op_params)[0]; + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + const float *src0_d = (const float *)src0->data; + const float *src1_d = (const float *)src1->data; + + float *dst_d = (float *)dst->data; + + if (dim != 3) { + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_sycl( + src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1], + src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); + } + } else { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + + SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait())); + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait())); + } + } else + concat_f32_sycl_non_cont( + stream, (const char *)src0->data, (const char *)src1->data, + (char *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], + src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], + dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); +} diff --git a/ggml/src/ggml-sycl/concat.hpp b/ggml/src/ggml-sycl/concat.hpp new file mode 100644 index 000000000..5a04feaab --- /dev/null +++ b/ggml/src/ggml-sycl/concat.hpp @@ -0,0 +1,21 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_CONCAT_HPP +#define GGML_SYCL_CONCAT_HPP + +#include "common.hpp" + +void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst); + +#endif // GGML_SYCL_CONCAT_HPP diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 101781ede..8efe32329 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -6561,7 +6561,7 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); + ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); } std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; @@ -6645,7 +6645,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * for (int i3 = 0; i3 < src0->ne[3]; i3++) { for (int i2 = 0; i2 < src0->ne[2]; i2++) { const int idx = i3*src0->ne[2] + i2; - ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); + ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); } } @@ -6658,7 +6658,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * if (offset + src0_size >= buffer_gpu->size) { src0_size = buffer_gpu->size - offset; } - ggml_vk_buffer_read(ctx, buffer_gpu, offset, src0_clone->data, src0_size); + ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size); memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); } } else { @@ -6687,7 +6687,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * for (int i3 = 0; i3 < src1->ne[3]; i3++) { for (int i2 = 0; i2 < src1->ne[2]; i2++) { const int idx = i3*src1->ne[2] + i2; - ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); + ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); } } @@ -6700,7 +6700,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * if (offset + src1_size >= buffer_gpu->size) { src1_size = buffer_gpu->size - offset; } - ggml_vk_buffer_read(ctx, buffer_gpu, offset, src1_clone->data, src1_size); + ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size); memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); } } else { @@ -6745,7 +6745,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * for (int i3 = 0; i3 < src2->ne[3]; i3++) { for (int i2 = 0; i2 < src2->ne[2]; i2++) { const int idx = i3*src2->ne[2] + i2; - ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); + ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); } } @@ -6758,7 +6758,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * if (offset + src2_size >= buffer_gpu->size) { src2_size = buffer_gpu->size - offset; } - ggml_vk_buffer_read(ctx, buffer_gpu, offset, src2_clone->data, src2_size); + ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size); memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); } } else { @@ -6922,7 +6922,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs); } - ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); + ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); } float first_error_result = -1.0f; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9a5414787..60b3c5e7a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -19478,7 +19478,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph fprintf(fp, "digraph G {\n"); fprintf(fp, " newrank = true;\n"); - fprintf(fp, " rankdir = LR;\n"); + fprintf(fp, " rankdir = TB;\n"); for (int i = 0; i < gb->n_nodes; i++) { struct ggml_tensor * node = gb->nodes[i]; @@ -19540,7 +19540,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph } fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); - if (ggml_nelements(node) < 5) { + if (ggml_nelements(node) < 5 && node->data != NULL) { fprintf(fp, " | ("); for (int j = 0; j < ggml_nelements(node); j++) { if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp index fa231c0ec..3038d647f 100644 --- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp @@ -270,10 +270,10 @@ void matmul_shaders(std::vector>& tasks, bool fp16, bool matmu std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2"; tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16); + string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16); })); tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "2"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16); })); } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a95a44237..5eb3df706 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -19,6 +19,7 @@ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h class Keys: class General: + TYPE = "general.type" ARCHITECTURE = "general.architecture" QUANTIZATION_VERSION = "general.quantization_version" ALIGNMENT = "general.alignment" @@ -120,11 +121,20 @@ class Keys: MIDDLE_ID = "tokenizer.ggml.middle_token_id" EOT_ID = "tokenizer.ggml.eot_token_id" + class Adapter: + TYPE = "adapter.type" + LORA_ALPHA = "adapter.lora.alpha" + # # recommended mapping of model tensor names for storage in gguf # +class GGUFType: + MODEL = "model" + ADAPTER = "adapter" + + class MODEL_ARCH(IntEnum): LLAMA = auto() FALCON = auto() diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index cf9554162..b0197961d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -424,6 +424,9 @@ class GGUFWriter: fout.close() self.fout = None + def add_type(self, type_name: str) -> None: + self.add_string(Keys.General.TYPE, type_name) + def add_architecture(self) -> None: self.add_string(Keys.General.ARCHITECTURE, self.arch) diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index b22eec166..16e0a9aaa 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -43,7 +43,7 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np. osize *= dim out = np.empty(shape=osize, dtype=otype) # compute over groups of 16 rows (arbitrary, but seems good for performance) - n_groups = rows.shape[0] // 16 + n_groups = (rows.shape[0] // 16) or 1 np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out) return out.reshape(oshape) diff --git a/include/llama.h b/include/llama.h index 3970c3aeb..c57d21f0c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -411,6 +411,9 @@ extern "C" { const char * content; } llama_chat_message; + // lora adapter + struct llama_lora_adapter; + // Helpers for getting default parameters LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); @@ -510,18 +513,28 @@ extern "C" { const char * fname_out, const llama_model_quantize_params * params); - // Apply a LoRA adapter to a loaded model - // path_base_model is the path to a higher quality model to use as a base for - // the layers modified by the adapter. Can be NULL to use the current loaded model. - // The model needs to be reloaded before applying a new adapter, otherwise the adapter - // will be applied on top of the previous one - // Returns 0 on success - LLAMA_API int32_t llama_model_apply_lora_from_file( - const struct llama_model * model, - const char * path_lora, - float scale, - const char * path_base_model, - int32_t n_threads); + // Load a LoRA adapter from file + // The loaded adapter will be associated to the given model, and will be free when the model is deleted + LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init( + struct llama_model * model, + const char * path_lora); + + // Add a loaded LoRA adapter to given context + // This will not modify model's weight + LLAMA_API int32_t llama_lora_adapter_set( + struct llama_context * ctx, + struct llama_lora_adapter * adapter, + float scale); + + // Remove a LoRA adapter from given context + // Return -1 if the adapter is not present in the context + LLAMA_API int32_t llama_lora_adapter_remove( + struct llama_context * ctx, + struct llama_lora_adapter * adapter); + + // Manually free a LoRA adapter + // Note: loaded adapters will be free when the associated model is deleted + LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter); // Apply a loaded control vector to a llama_context, or if data is NULL, clear // the currently loaded vector. diff --git a/requirements.txt b/requirements.txt index 52456c2e6..9e190ae27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ -r ./requirements/requirements-convert_hf_to_gguf.txt -r ./requirements/requirements-convert_hf_to_gguf_update.txt -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt +-r ./requirements/requirements-convert_lora_to_gguf.txt diff --git a/requirements/requirements-convert_lora_to_gguf.txt b/requirements/requirements-convert_lora_to_gguf.txt new file mode 100644 index 000000000..5758076c4 --- /dev/null +++ b/requirements/requirements-convert_lora_to_gguf.txt @@ -0,0 +1,2 @@ +-r ./requirements-convert_hf_to_gguf.txt +--extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements/requirements-pydantic.txt b/requirements/requirements-pydantic.txt index 2f9455b14..bdd423e07 100644 --- a/requirements/requirements-pydantic.txt +++ b/requirements/requirements-pydantic.txt @@ -1,2 +1,3 @@ docstring_parser~=0.15 pydantic~=2.6.3 +requests diff --git a/src/llama.cpp b/src/llama.cpp index 400a4232b..07bb42713 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -287,6 +287,7 @@ static const std::map LLM_ARCH_NAMES = { }; enum llm_kv { + LLM_KV_GENERAL_TYPE, LLM_KV_GENERAL_ARCHITECTURE, LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, @@ -377,9 +378,13 @@ enum llm_kv { LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, LLM_KV_TOKENIZER_EOT_ID, + + LLM_KV_ADAPTER_TYPE, + LLM_KV_ADAPTER_LORA_ALPHA, }; static const std::map LLM_KV_NAMES = { + { LLM_KV_GENERAL_TYPE, "general.type" }, { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, @@ -470,6 +475,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, }; struct LLM_KV { @@ -2703,6 +2711,9 @@ struct llama_model { int64_t t_load_us = 0; int64_t t_start_us = 0; + // keep track of loaded lora adapters + std::set lora_adapters; + ~llama_model() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); @@ -2715,6 +2726,9 @@ struct llama_model { #endif ggml_backend_buffer_free(buf); } + while (!lora_adapters.empty()) { + llama_lora_adapter_free(*lora_adapters.begin()); + } } }; @@ -2819,6 +2833,52 @@ struct llama_context { // control vectors struct llama_control_vector cvec; + + // lora adapters and scales + std::unordered_map lora_adapters; +}; + +struct llama_lora_weight { + struct ggml_tensor * a = nullptr; + struct ggml_tensor * b = nullptr; + llama_lora_weight() = default; + llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {} +}; + +struct llama_lora_adapter { + struct llama_model * base_model; + // map tensor name to lora_a_b + std::unordered_map ab_map; + std::vector ctxs; + std::vector bufs; + + float alpha; + + llama_lora_adapter(struct llama_model * base_model): base_model(base_model) { + base_model->lora_adapters.insert(this); + } + + llama_lora_weight * get_weight(struct ggml_tensor * w) { + std::string name(w->name); + auto pos = ab_map.find(name); + if (ab_map.find(name) != ab_map.end()) { + return &pos->second; + } + return nullptr; + } + + ~llama_lora_adapter() { + for (struct ggml_context * ctx : ctxs) { + ggml_free(ctx); + } + for (ggml_backend_buffer_t buf : bufs) { + ggml_backend_buffer_free(buf); + } + auto pos = base_model->lora_adapters.find(this); + if (pos != base_model->lora_adapters.end()) { + base_model->lora_adapters.erase(pos); + } + } }; static size_t llama_get_device_count(const llama_model & model) { @@ -7809,6 +7869,58 @@ static void llm_build_kv_store( ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); } +// do mat_mul, while optionally apply lora +static struct ggml_tensor * llm_build_lora_mm( + struct llama_context & lctx, + struct ggml_context * ctx0, + struct ggml_tensor * w, + struct ggml_tensor * cur) { + struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); + for (auto & it : lctx.lora_adapters) { + struct llama_lora_weight * lora = it.first->get_weight(w); + if (lora == nullptr) { + continue; + } + const float alpha = it.first->alpha; + const float rank = (float) lora->b->ne[0]; + const float scale = alpha ? it.second * alpha / rank : it.second; + struct ggml_tensor * ab_cur = ggml_mul_mat( + ctx0, lora->b, + ggml_mul_mat(ctx0, lora->a, cur) + ); + ab_cur = ggml_scale(ctx0, ab_cur, scale); + res = ggml_add(ctx0, res, ab_cur); + } + return res; +} + +// do mat_mul_id, while optionally apply lora +static struct ggml_tensor * llm_build_lora_mm_id( + struct llama_context & lctx, + struct ggml_context * ctx0, + struct ggml_tensor * w, // struct ggml_tensor * as + struct ggml_tensor * cur, // struct ggml_tensor * b + struct ggml_tensor * ids) { + struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids); + for (auto & it : lctx.lora_adapters) { + struct llama_lora_weight * lora = it.first->get_weight(w); + if (lora == nullptr) { + continue; + } + const float alpha = it.first->alpha; + const float rank = (float) lora->b->ne[0]; + const float scale = alpha ? it.second * alpha / rank : it.second; + struct ggml_tensor * ab_cur = ggml_mul_mat_id( + ctx0, lora->b, + ggml_mul_mat_id(ctx0, lora->a, cur, ids), + ids + ); + ab_cur = ggml_scale(ctx0, ab_cur, scale); + res = ggml_add(ctx0, res, ab_cur); + } + return res; +} + static struct ggml_tensor * llm_build_norm( struct ggml_context * ctx, struct ggml_tensor * cur, @@ -7843,6 +7955,7 @@ static struct ggml_tensor * llm_build_norm( static struct ggml_tensor * llm_build_ffn( struct ggml_context * ctx, + struct llama_context & lctx, struct ggml_tensor * cur, struct ggml_tensor * up, struct ggml_tensor * up_b, @@ -7858,7 +7971,7 @@ static struct ggml_tensor * llm_build_ffn( llm_ffn_gate_type type_gate, const llm_build_cb & cb, int il) { - struct ggml_tensor * tmp = up ? ggml_mul_mat(ctx, up, cur) : cur; + struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur; cb(tmp, "ffn_up", il); if (up_b) { @@ -7875,12 +7988,12 @@ static struct ggml_tensor * llm_build_ffn( switch (type_gate) { case LLM_FFN_SEQ: { - cur = ggml_mul_mat(ctx, gate, tmp); + cur = llm_build_lora_mm(lctx, ctx, gate, tmp); cb(cur, "ffn_gate", il); } break; case LLM_FFN_PAR: { - cur = ggml_mul_mat(ctx, gate, cur); + cur = llm_build_lora_mm(lctx, ctx, gate, cur); cb(cur, "ffn_gate", il); } break; } @@ -7948,7 +8061,7 @@ static struct ggml_tensor * llm_build_ffn( } if (down) { - cur = ggml_mul_mat(ctx, down, cur); + cur = llm_build_lora_mm(lctx, ctx, down, cur); } if (down_b) { @@ -7969,6 +8082,7 @@ static struct ggml_tensor * llm_build_ffn( static struct ggml_tensor * llm_build_moe_ffn( struct ggml_context * ctx, + struct llama_context & lctx, struct ggml_tensor * cur, struct ggml_tensor * gate_inp, struct ggml_tensor * up_exps, @@ -7985,7 +8099,7 @@ static struct ggml_tensor * llm_build_moe_ffn( int64_t n_embd = cur->ne[0]; int64_t n_tokens = cur->ne[1]; - ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens] + ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] @@ -8017,10 +8131,10 @@ static struct ggml_tensor * llm_build_moe_ffn( } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); - ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); - ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(gate, "ffn_moe_gate", il); switch (type_op) { @@ -8041,7 +8155,7 @@ static struct ggml_tensor * llm_build_moe_ffn( ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] cb(par, "ffn_moe_gate_par", il); - ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); experts = ggml_mul(ctx, experts, weights); @@ -8069,9 +8183,7 @@ static struct ggml_tensor * llm_build_moe_ffn( static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, - const llama_model & model, - const llama_hparams & hparams, - const llama_cparams & cparams, + struct llama_context & lctx, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -8083,6 +8195,10 @@ static struct ggml_tensor * llm_build_kqv( float kq_scale, const llm_build_cb & cb, int il) { + const llama_model & model = lctx.model; + const llama_hparams & hparams = lctx.model.hparams; + const llama_cparams & cparams = lctx.cparams; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head(il); const int64_t n_head_kv = hparams.n_head_kv(il); @@ -8181,7 +8297,7 @@ static struct ggml_tensor * llm_build_kqv( ggml_build_forward_expand(graph, cur); if (wo) { - cur = ggml_mul_mat(ctx, wo, cur); + cur = llm_build_lora_mm(lctx, ctx, wo, cur); } if (wo_b) { @@ -8197,9 +8313,7 @@ static struct ggml_tensor * llm_build_kqv( static struct ggml_tensor * llm_build_kv( struct ggml_context * ctx, - const llama_model & model, - const llama_hparams & hparams, - const llama_cparams & cparams, + struct llama_context & lctx, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -8214,6 +8328,8 @@ static struct ggml_tensor * llm_build_kv( float kq_scale, const llm_build_cb & cb, int il) { + const llama_hparams & hparams = lctx.model.hparams; + const llama_cparams & cparams = lctx.cparams; // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced @@ -8225,7 +8341,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, + cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); @@ -8687,21 +8803,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -8722,7 +8838,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -8745,7 +8861,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -8759,7 +8875,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_moe_ffn(ctx0, cur, + cur = llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -8789,7 +8905,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -8825,13 +8941,13 @@ struct llm_build_context { // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); switch (model.type) { @@ -8857,7 +8973,7 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -8879,7 +8995,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -8904,7 +9020,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -8940,13 +9056,13 @@ struct llm_build_context { // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -8962,7 +9078,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -8984,7 +9100,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -9007,7 +9123,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9056,7 +9172,7 @@ struct llm_build_context { cur = attn_norm; } - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); @@ -9083,7 +9199,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9100,7 +9216,7 @@ struct llm_build_context { // feed forward { - cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result + cur = llm_build_ffn(ctx0, lctx, attn_norm, // !! use the attn norm, not the result model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -9127,7 +9243,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9172,21 +9288,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -9207,7 +9323,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -9239,7 +9355,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_moe_ffn(ctx0, cur, + cur = llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -9278,7 +9394,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); // Grok // multiply logits by output_multiplier_scale of 0.5773502691896257 @@ -9329,7 +9445,7 @@ struct llm_build_context { struct ggml_tensor * Kcur = nullptr; struct ggml_tensor * Vcur = nullptr; - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); @@ -9357,7 +9473,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9380,7 +9496,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "attn_out_norm", il); - cur = llm_build_moe_ffn(ctx0, cur, + cur = llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -9409,7 +9525,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); @@ -9451,7 +9567,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -9467,7 +9583,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9491,7 +9607,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -9514,7 +9630,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9546,13 +9662,13 @@ struct llm_build_context { // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); @@ -9561,7 +9677,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9583,7 +9699,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -9608,7 +9724,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9660,7 +9776,7 @@ struct llm_build_context { // self-attention if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { - Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); + Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq); cb(Qcur, "Qcur", il); if (model.layers[il].attn_q_norm) { @@ -9670,7 +9786,7 @@ struct llm_build_context { LLM_NORM, cb, il); } - Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); + Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur), model.layers[il].bk); cb(Kcur, "Kcur", il); if (model.layers[il].attn_k_norm) { @@ -9679,14 +9795,14 @@ struct llm_build_context { model.layers[il].attn_k_norm_b, LLM_NORM, cb, il); } - Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); + Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur), model.layers[il].bv); cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } else { // compute Q and K and RoPE them - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); @@ -9735,7 +9851,7 @@ struct llm_build_context { ggml_build_forward_expand(gf, cur); - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); if (model.layers[il].bo) { cb(cur, "kqv_wo", il); } @@ -9768,21 +9884,21 @@ struct llm_build_context { // feed-forward network if (model.arch == LLM_ARCH_BERT) { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_PAR, cb, il); } else { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -9840,7 +9956,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -9856,7 +9972,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9880,7 +9996,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -9903,7 +10019,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9950,7 +10066,7 @@ struct llm_build_context { { cur = attn_norm; - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); if (model.layers[il].bqkv){ @@ -9988,13 +10104,13 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10018,7 +10134,7 @@ struct llm_build_context { model.layers[il].ffn_norm_b, LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -10043,7 +10159,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10083,21 +10199,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -10139,7 +10255,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10167,7 +10283,7 @@ struct llm_build_context { // parallel residual cur = inpSA; } - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -10193,7 +10309,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10228,7 +10344,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -10258,7 +10374,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10280,7 +10396,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -10305,7 +10421,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10343,17 +10459,17 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); @@ -10372,7 +10488,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10393,7 +10509,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -10417,7 +10533,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10458,17 +10574,17 @@ struct llm_build_context { // self_attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); @@ -10487,7 +10603,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10510,7 +10626,7 @@ struct llm_build_context { cb(cur, "ffn_norm", il); ggml_tensor * moe_out = - llm_build_moe_ffn(ctx0, cur, + llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -10523,14 +10639,14 @@ struct llm_build_context { // FFN shared expert { - ggml_tensor * cur_gate_inp = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp_shexp, cur); + ggml_tensor * cur_gate_inp = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur); cb(cur_gate_inp, "ffn_shexp_gate_inp", il); // sigmoid ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp); cb(cur_gate, "ffn_shexp_gate", il); - ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur, + ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -10563,7 +10679,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10605,7 +10721,7 @@ struct llm_build_context { struct ggml_tensor * Vcur = nullptr; if (model.layers[il].wqkv) { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -10615,9 +10731,9 @@ struct llm_build_context { Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); } else { - Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); } cb(Qcur, "Qcur", il); @@ -10644,7 +10760,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -10659,7 +10775,7 @@ struct llm_build_context { // FF { - ffn_output = llm_build_ffn(ctx0, attn_norm_output, + ffn_output = llm_build_ffn(ctx0, lctx, attn_norm_output, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -10683,7 +10799,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); @@ -10729,7 +10845,7 @@ struct llm_build_context { struct ggml_tensor * Vcur = nullptr; if (model.layers[il].wqkv) { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output); cb(cur, "wqkv", il); Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); @@ -10737,9 +10853,9 @@ struct llm_build_context { Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); } else { - Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); } cb(Qcur, "Qcur", il); @@ -10764,7 +10880,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -10788,7 +10904,7 @@ struct llm_build_context { // special-case: the up and gate tensors are merged into a single tensor // TOOD: support into llm_build_ffn { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -10811,7 +10927,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10851,13 +10967,13 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -10872,7 +10988,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10890,7 +11006,7 @@ struct llm_build_context { // feed-forward network { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -10916,7 +11032,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10958,7 +11074,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -10974,7 +11090,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10998,7 +11114,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -11021,7 +11137,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11057,7 +11173,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -11085,7 +11201,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11109,7 +11225,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -11132,7 +11248,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11170,21 +11286,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); // if (model.layers[il].bq) { // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); // cb(Qcur, "Qcur", il); // } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); // if (model.layers[il].bk) { // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); // cb(Kcur, "Kcur", il); // } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); // if (model.layers[il].bv) { // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11205,7 +11321,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11226,7 +11342,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11250,7 +11366,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11288,21 +11404,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11323,7 +11439,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11344,7 +11460,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11368,7 +11484,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11419,21 +11535,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11454,7 +11570,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11481,7 +11597,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11515,7 +11631,7 @@ struct llm_build_context { cb(cur, "lmhead_scaling", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11552,13 +11668,13 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -11576,7 +11692,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -11598,7 +11714,7 @@ struct llm_build_context { // feed-forward network { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11623,7 +11739,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11665,13 +11781,13 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -11694,7 +11810,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -11721,7 +11837,7 @@ struct llm_build_context { // feed-forward network { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11751,7 +11867,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); // final logit soft-capping cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); @@ -11796,21 +11912,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11831,7 +11947,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11853,7 +11969,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -11877,7 +11993,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11929,7 +12045,7 @@ struct llm_build_context { cb(cur, "attn_norm", il); // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} - struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur); + struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur); // split the above in two // => {d_inner, n_tokens} struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); @@ -11968,14 +12084,14 @@ struct llm_build_context { // ssm { // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} - struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x); + struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x); // split struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} - dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt); + dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); // Custom operator to optimize the parallel associative scan @@ -12006,7 +12122,7 @@ struct llm_build_context { y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} - cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y); } // residual @@ -12025,7 +12141,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12064,21 +12180,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -12124,7 +12240,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12141,7 +12257,7 @@ struct llm_build_context { // feed-forward network { - cur = llm_build_ffn(ctx0, ffn_inp, + cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12168,7 +12284,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); @@ -12221,21 +12337,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (hparams.f_clamp_kqv > 0.0f) { Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (hparams.f_clamp_kqv > 0.0f) { Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (hparams.f_clamp_kqv > 0.0f) { Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); @@ -12256,7 +12372,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12278,7 +12394,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12304,7 +12420,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12344,7 +12460,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens); @@ -12383,7 +12499,7 @@ struct llm_build_context { Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens); cb(Qcur, "Vcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12405,7 +12521,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12429,7 +12545,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12464,7 +12580,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -12492,7 +12608,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12517,7 +12633,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -12548,7 +12664,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -12571,7 +12687,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12612,13 +12728,13 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -12635,7 +12751,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12657,7 +12773,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12674,7 +12790,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm_exps", il); - cur = llm_build_moe_ffn(ctx0, cur, + cur = llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -12703,7 +12819,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12857,7 +12973,7 @@ struct llm_build_context { struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(k_states, "k_states", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } @@ -12873,13 +12989,13 @@ struct llm_build_context { struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - if ((uint32_t) il < hparams.n_layer_dense_lead) { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12888,13 +13004,8 @@ struct llm_build_context { cb(cur, "ffn_out", il); } else { // MoE branch - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - ggml_tensor * moe_out = - llm_build_moe_ffn(ctx0, cur, + llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -12907,7 +13018,7 @@ struct llm_build_context { // FFN shared expert { - ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur, + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -12972,7 +13083,7 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { @@ -12981,7 +13092,7 @@ struct llm_build_context { } // B1.K - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { @@ -12990,7 +13101,7 @@ struct llm_build_context { } // B1.V - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { @@ -13012,7 +13123,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, NULL, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); @@ -13021,7 +13132,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_sub_norm", il); - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); if (model.layers[il].bo) { cur = ggml_add(ctx0, cur, model.layers[il].bo); @@ -13045,7 +13156,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, NULL, NULL, NULL, @@ -13058,7 +13169,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_sub_norm", il); - cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur); cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); cb(cur, "ffn_down", il); @@ -13077,7 +13188,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.tok_embd, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -13179,7 +13290,7 @@ struct llm_build_context { cb(cur, "ffn_norm", il); // T5 uses relu, flan-T5 uses gelu-gated - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up_enc, NULL, NULL, model.layers[il].ffn_gate_enc, NULL, NULL, model.layers[il].ffn_down_enc, NULL, NULL, @@ -13359,7 +13470,7 @@ struct llm_build_context { cb(cur, "ffn_norm", il); // T5 uses relu, flan-T5 uses gelu-gated - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -13425,7 +13536,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -13441,7 +13552,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il); } @@ -13465,7 +13576,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -13484,7 +13595,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); @@ -13526,7 +13637,7 @@ struct llm_build_context { struct ggml_tensor * Kcur = nullptr; struct ggml_tensor * Vcur = nullptr; - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -13554,7 +13665,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur_rope", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); @@ -13579,7 +13690,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -13599,7 +13710,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -18463,284 +18574,212 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } -static int llama_apply_lora_from_file_internal( - const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads -) { - LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); +static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) { + LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); - const int64_t t_start_lora_us = ggml_time_us(); - - llama_file fin(path_lora, "rb"); - - // verify magic and version - { - uint32_t magic = fin.read_u32(); - if (magic != LLAMA_FILE_MAGIC_GGLA) { - LLAMA_LOG_ERROR("%s: bad file magic\n", __func__); - return 1; - } - - uint32_t format_version = fin.read_u32(); - if (format_version != 1) { - LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ ); - return 1; - } - } - - int32_t lora_r = fin.read_u32(); - int32_t lora_alpha = fin.read_u32(); - float scaling = scale * (float)lora_alpha / (float)lora_r; - - LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); - - // load base model - std::unique_ptr ml; - if (path_base_model) { - LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); - ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr)); - ml->init_mappings(/*prefetch*/ false); // no prefetching - } - - struct tensor_meta { - std::string name; - ggml_type type; - int32_t ne[2]; - size_t offset; + ggml_context * ctx = nullptr; + struct gguf_init_params meta_gguf_params = { + /* .no_alloc = */ true, + /* .ctx = */ &ctx, }; - std::map tensor_meta_map; - - // load all tensor meta - while (true) { - if (fin.tell() == fin.size) { - // eof - break; - } - - int32_t n_dims; - int32_t name_len; - int32_t ftype; - - fin.read_raw(&n_dims, sizeof(n_dims)); - fin.read_raw(&name_len, sizeof(name_len)); - fin.read_raw(&ftype, sizeof(ftype)); - - if (n_dims != 1 && n_dims != 2) { - LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims); - return 1; - } - - int32_t ne[2] = { 1, 1 }; - for (int i = 0; i < n_dims; ++i) { - fin.read_raw(&ne[i], sizeof(ne[i])); - } - - std::string name; - { - GGML_ASSERT(name_len < GGML_MAX_NAME); - char buf[GGML_MAX_NAME]; - fin.read_raw(buf, name_len); - name = std::string(buf, name_len); - } - - // check for lora suffix - std::string lora_suffix; - if (name.length() > 6) { - lora_suffix = name.substr(name.length() - 6); - } - if (lora_suffix != ".loraA" && lora_suffix != ".loraB") { - LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str()); - return 1; - } - - // tensor type - ggml_type wtype; - switch (ftype) { - case 0: wtype = GGML_TYPE_F32; break; - case 1: wtype = GGML_TYPE_F16; break; - default: - { - LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n", - __func__, ftype); - return 1; - } - } - - // data offset - size_t offset = fin.tell(); - offset = (offset + 31) & -32; - - // skip tensor data - fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET); - - tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset }); + struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params); + if (!ctx_gguf) { + throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora)); } - bool warned = false; - int n_tensors = 0; - - // apply - ggml_backend_t backend_cpu = ggml_backend_cpu_init(); - if (backend_cpu == nullptr) { - LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__); - return 1; - } - ggml_backend_cpu_set_n_threads(backend_cpu, n_threads); - - std::vector> read_buf; - for (const auto & it : model.tensors_by_name) { - const std::string & base_name = it.first; - ggml_tensor * model_t = it.second; - - if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() || - tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) { - continue; - } - - tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA"); - tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB"); - - ggml_init_params lora_init_params = { - /* .mem_size */ ggml_tensor_overhead()*128 + ggml_graph_overhead(), - /* .mem_buffer */ nullptr, - /* .no_alloc */ true, + // check metadata + { + auto get_kv_str = [&](const std::string & key) -> std::string { + int id = gguf_find_key(ctx_gguf, key.c_str()); + return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id)); }; - ggml_context * lora_ctx = ggml_init(lora_init_params); - if (lora_ctx == nullptr) { - LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__); - ggml_backend_free(backend_cpu); - return 1; + auto get_kv_f32 = [&](const std::string & key) -> float { + int id = gguf_find_key(ctx_gguf, key.c_str()); + return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf, id); + }; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE)); + if (general_type != "adapter") { + gguf_free(ctx_gguf); + throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type); } - // create tensors - ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]); - ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]); - ggml_set_name(loraA, metaA.name.c_str()); - ggml_set_name(loraB, metaB.name.c_str()); + auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE)); + auto general_arch = llm_arch_from_string(general_arch_str); + if (general_arch != model->arch) { + gguf_free(ctx_gguf); + throw std::runtime_error("model arch and LoRA arch mismatch"); + } - ggml_tensor * base_t; - if (ml) { - if (!ml->get_tensor_meta(base_name.c_str())) { - LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); - return 1; + auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE)); + if (adapter_type != "lora") { + gguf_free(ctx_gguf); + throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type); + } + + adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA)); + } + + int n_tensors = gguf_get_n_tensors(ctx_gguf); + + // contexts for each buffer type + std::map ctx_map; + auto get_ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + // add a new context + struct ggml_init_params params = { + /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * buft_ctx = ggml_init(params); + ctx_map[buft] = buft_ctx; + return buft_ctx; + }; + return it->second; + }; + + // bundle lora_a and lora_b into pairs + std::map ab_map; + auto str_endswith = [](const std::string & str, const std::string & suffix) { + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; + }; + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string name(cur->name); + if (str_endswith(name, ".lora_a")) { + replace_all(name, ".lora_a", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_lora_weight(cur, nullptr); + } else { + ab_map[name].a = cur; + } + } else if (str_endswith(name, ".lora_b")) { + replace_all(name, ".lora_b", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_lora_weight(nullptr, cur); + } else { + ab_map[name].b = cur; } - base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str())); } else { - base_t = ggml_dup_tensor(lora_ctx, model_t); - } - ggml_set_name(base_t, base_name.c_str()); - - // allocate in backend buffer - ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type()); - if (lora_buf == nullptr) { - LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__); - return 1; - } - - // load tensor data - auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) { - read_buf.resize(ggml_nbytes(tensor)); - fin.seek(tensor_meta.offset, SEEK_SET); - fin.read_raw(read_buf.data(), ggml_nbytes(tensor)); - ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size()); - }; - load_tensor(metaA, loraA); - load_tensor(metaB, loraB); - - // load base model tensor data - if (ml) { - ml->load_data_for(base_t); - } else { - ggml_backend_tensor_copy(model_t, base_t); - } - - if (ggml_is_quantized(base_t->type) && !warned) { - LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, " - "use a f16 or f32 base model with --lora-base\n", __func__); - warned = true; - } - - if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { - LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" - " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); - ggml_free(lora_ctx); - ggml_backend_buffer_free(lora_buf); - ggml_backend_free(backend_cpu); - return 1; - } - - auto build_lora_graph = [&]() { - // w = w + BA*s - ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB); - ggml_set_name(BA, "BA"); - - if (scaling != 1.0f) { - BA = ggml_scale(lora_ctx, BA, scaling); - ggml_set_name(BA, "BA_scaled"); - } - - ggml_tensor * r; - r = ggml_add_inplace(lora_ctx, base_t, BA); - ggml_set_name(r, "r_add"); - - if (base_t->type != model_t->type) { - // convert the result to the model type - r = ggml_cast(lora_ctx, r, model_t->type); - ggml_set_name(r, "r_cast"); - } - - return r; - }; - - ggml_cgraph * gf = ggml_new_graph(lora_ctx); - ggml_tensor * r = build_lora_graph(); - ggml_build_forward_expand(gf, r); - - ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type()); - if (graph_buf == nullptr) { - LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__); - ggml_free(lora_ctx); - ggml_backend_buffer_free(lora_buf); - ggml_backend_free(backend_cpu); - return 1; - } - - ggml_backend_graph_compute(backend_cpu, gf); - - ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r)); - -#if 0 - // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU - //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE); - - // sched compute - ggml_build_forward_expand(gf, build_graph()); - ggml_backend_sched_init_measure(sched, gf); - - // create the graph again, since the previous one was destroyed by the measure - ggml_graph_clear(gf); - ggml_build_forward_expand(gf, build_graph()); - ggml_backend_sched_graph_compute(sched, gf); - ggml_backend_sched_free(sched); -#endif - - ggml_backend_buffer_free(lora_buf); - ggml_backend_buffer_free(graph_buf); - ggml_free(lora_ctx); - - n_tensors++; - if (n_tensors % 4 == 0) { - LLAMA_LOG_INFO("."); + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); } } - ggml_backend_free(backend_cpu); + // add tensors + for (auto & it : ab_map) { + const std::string & name = it.first; + llama_lora_weight & w = it.second; - const int64_t t_lora_us = ggml_time_us() - t_start_lora_us; - LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0); + if (!w.a || !w.b) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); + } + // device buft and device ctx + auto * model_tensor = llama_get_model_tensor(model, name.c_str()); + if (!model_tensor) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model"); + } + struct ggml_context * dev_ctx = get_ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); + // validate tensor shape + if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("tensor '" + name + "' has incorrect shape"); + } + if (w.a->ne[1] != w.b->ne[0]) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + } + // save tensor to adapter + struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); + struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); + ggml_set_name(tensor_a, w.a->name); + ggml_set_name(tensor_b, w.b->name); + adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b); + } + + // allocate tensors / buffers and zero + { + adapter.ctxs.reserve(ctx_map.size()); + adapter.bufs.reserve(ctx_map.size()); + for (auto it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx_dev = it.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft); + if (!buf) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("failed to allocate buffer for lora adapter\n"); + } + LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + adapter.ctxs.push_back(ctx_dev); + adapter.bufs.push_back(buf); + } + } + + // set tensor data + { + llama_file gguf_file(path_lora, "rb"); + std::vector read_buf; + auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) { + size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, gguf_find_tensor(ctx_gguf, orig->name)); + size_t size = ggml_nbytes(orig); + read_buf.resize(size); + gguf_file.seek(offs, SEEK_SET); + gguf_file.read_raw(read_buf.data(), size); + ggml_backend_tensor_set(dev, read_buf.data(), 0, size); + }; + for (auto & it : adapter.ab_map) { + auto orig = ab_map[it.first]; + auto dev = it.second; + set_tensor(orig.a, dev.a); + set_tensor(orig.b, dev.b); + } + } + + LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2); + + // free ctx for reading gguf + gguf_free(ctx_gguf); + ggml_free(ctx); +} + +int32_t llama_lora_adapter_set( + struct llama_context * ctx, + struct llama_lora_adapter * adapter, + float scale) { + if (ctx->cparams.flash_attn) { + LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__); + return -1; + } + ctx->lora_adapters[adapter] = scale; return 0; } +int32_t llama_lora_adapter_remove( + struct llama_context * ctx, + struct llama_lora_adapter * adapter) { + auto pos = ctx->lora_adapters.find(adapter); + if (pos != ctx->lora_adapters.end()) { + ctx->lora_adapters.erase(pos); + return 0; + } + return -1; +} + +void llama_lora_adapter_free(struct llama_lora_adapter * adapter) { + delete adapter; +} + // // interface implementation // @@ -19519,12 +19558,14 @@ uint32_t llama_model_quantize( } } -int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) { +struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) { try { - return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads); + struct llama_lora_adapter * adapter = new llama_lora_adapter(model); + llama_lora_adapter_init_internal(model, path_lora, *adapter); + return adapter; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); - return 1; + return nullptr; } }