Merge branch 'master' into compilade/faster-lazy-safetensors

This commit is contained in:
Francis Couture-Harpin 2024-07-15 15:24:25 -04:00
commit 2a49a68d70
25 changed files with 1531 additions and 720 deletions

View File

@ -685,7 +685,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
if (arg == "--lora") {
CHECK_ARG
params.lora_adapter.emplace_back(argv[i], 1.0f);
params.use_mmap = false;
return true;
}
if (arg == "--lora-scaled") {
@ -693,7 +692,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
const char* lora_adapter = argv[i];
CHECK_ARG
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
params.use_mmap = false;
return true;
}
if (arg == "--lora-base") {
@ -797,6 +795,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cont_batching = true;
return true;
}
if (arg == "-nocb" || arg == "--no-cont-batching") {
params.cont_batching = false;
return true;
}
if (arg == "-fa" || arg == "--flash-attn") {
params.flash_attn = true;
return true;
@ -1538,6 +1540,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-np, --parallel N", "number of parallel sequences to decode (default: %d)", params.n_parallel });
options.push_back({ "*", "-ns, --sequences N", "number of sequences to decode (default: %d)", params.n_sequences });
options.push_back({ "*", "-cb, --cont-batching", "enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled" });
options.push_back({ "*", "-nocb, --no-cont-batching", "disable continuous batching" });
options.push_back({ "multi-modality" });
options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" });
@ -2084,19 +2087,14 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
int err = llama_model_apply_lora_from_file(model,
lora_adapter.c_str(),
lora_scale,
((i > 0) || params.lora_base.empty())
? NULL
: params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
if (adapter == nullptr) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
}
llama_lora_adapter_set(lctx, adapter, lora_scale);
}
if (params.ignore_eos) {

View File

@ -2271,13 +2271,6 @@ class InternLM2Model(Model):
special_vocab.add_to_gguf(self.gguf_writer)
def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))
def set_gguf_parameters(self):
self.gguf_writer.add_name("InternLM2")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
@ -2297,26 +2290,22 @@ class InternLM2Model(Model):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
num_heads = self.hparams["num_attention_heads"]
num_kv_heads = self.hparams["num_key_value_heads"]
hidden_size = self.hparams["hidden_size"]
n_embd = self.hparams["hidden_size"]
q_per_kv = num_heads // num_kv_heads
head_dim = hidden_size // num_heads
head_dim = n_embd // num_heads
num_groups = num_heads // q_per_kv
qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
if re.match(qkv_pattern, name):
bid = re.findall(qkv_pattern, name)[0]
if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
qkv = data_torch
# qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1]
# The model weights of q and k equire additional reshape.
# q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
# k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
# v = rearrange(v, " o g n i -> o (g n i)").T
v = v.reshape((v.shape[0], -1)).T
q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads)
k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads)
v = v.reshape((-1, v.shape[-1]))
return [
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
@ -3620,6 +3609,7 @@ def main() -> None:
small_first_shard=args.no_tensor_first_split)
logger.info("Set model parameters")
model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL)
model_instance.set_gguf_parameters()
logger.info("Set model tokenizer")

374
convert_lora_to_gguf.py Executable file
View File

@ -0,0 +1,374 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from dataclasses import dataclass
import logging
import argparse
import os
import sys
import json
from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
import torch
if TYPE_CHECKING:
from torch import Tensor
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, Model
logger = logging.getLogger("lora-to-gguf")
@dataclass
class PartialLoraTensor:
A: Tensor | None = None
B: Tensor | None = None
# magic to support tensor shape modifications and splitting
class LoraTorchTensor:
_lora_A: Tensor # (n_rank, row_size)
_lora_B: Tensor # (col_size, n_rank)
_rank: int
def __init__(self, A: Tensor, B: Tensor):
assert len(A.shape) == len(B.shape)
assert A.shape[-2] == B.shape[-1]
if A.dtype != B.dtype:
A = A.to(torch.float32)
B = B.to(torch.float32)
self._lora_A = A
self._lora_B = B
self._rank = B.shape[-1]
def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
return (self._lora_A, self._lora_B)
def __getitem__(
self,
indices: (
SupportsIndex
| slice
| tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature
),
) -> LoraTorchTensor:
shape = self.shape
if isinstance(indices, SupportsIndex):
if len(shape) > 2:
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
else:
raise NotImplementedError # can't return a vector
elif isinstance(indices, slice):
if len(shape) > 2:
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
else:
return LoraTorchTensor(self._lora_A, self._lora_B[indices])
elif isinstance(indices, tuple):
assert len(indices) > 0
if indices[-1] is Ellipsis:
return self[indices[:-1]]
# expand ellipsis
indices = tuple(
u
for v in (
(
(slice(None, None) for _ in range(len(indices) - 1))
if i is Ellipsis
else (i,)
)
for i in indices
)
for u in v
)
if len(indices) < len(shape):
indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
# TODO: make sure this is correct
indices_A = (
*(
(
j.__index__() % self._lora_A.shape[i]
if isinstance(j, SupportsIndex)
else slice(None, None)
)
for i, j in enumerate(indices[:-2])
),
slice(None, None),
indices[-1],
)
indices_B = indices[:-1]
return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
else:
raise NotImplementedError # unknown indice type
@property
def dtype(self) -> torch.dtype:
assert self._lora_A.dtype == self._lora_B.dtype
return self._lora_A.dtype
@property
def shape(self) -> tuple[int, ...]:
assert len(self._lora_A.shape) == len(self._lora_B.shape)
return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
def size(self, dim=None):
assert dim is None
return self.shape
def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
if isinstance(shape[0], tuple):
new_shape: tuple[int, ...] = shape[0]
else:
new_shape = cast(tuple[int, ...], shape)
orig_shape = self.shape
if len(new_shape) < 2:
raise NotImplementedError # can't become a vector
# expand -1 in the shape
if any(dim == -1 for dim in new_shape):
n_elems = prod(orig_shape)
n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
assert n_elems % n_new_elems == 0
new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
if new_shape[-1] != orig_shape[-1]:
raise NotImplementedError # can't reshape the row size trivially
shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
shape_B = (*new_shape[:-1], self._rank)
return LoraTorchTensor(
self._lora_A.reshape(shape_A),
self._lora_B.reshape(shape_B),
)
def reshape_as(self, other: Tensor) -> LoraTorchTensor:
return self.reshape(*other.shape)
def view(self, *size: int) -> LoraTorchTensor:
return self.reshape(*size)
def permute(self, *dims: int) -> LoraTorchTensor:
shape = self.shape
dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
if dims[-1] == -1:
# TODO: support higher dimensional A shapes bigger than 1
assert all(dim == 1 for dim in self._lora_A.shape[:-2])
return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
else:
# TODO: compose the above two
raise NotImplementedError
def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
shape = self.shape
dims = [i for i in range(len(shape))]
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
return self.permute(*dims)
def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
return self.transpose(axis0, axis1)
def to(self, *args, **kwargs):
return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
@classmethod
def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
del types # unused
if kwargs is None:
kwargs = {}
if func is torch.permute:
return type(args[0]).permute(*args, **kwargs)
elif func is torch.reshape:
return type(args[0]).reshape(*args, **kwargs)
elif func is torch.stack:
assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0)
assert dim == 0
return LoraTorchTensor(
torch.stack([a._lora_A for a in args[0]], dim),
torch.stack([b._lora_B for b in args[0]], dim),
)
elif func is torch.cat:
assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0)
assert dim == 0
if len(args[0][0].shape) > 2:
return LoraTorchTensor(
torch.cat([a._lora_A for a in args[0]], dim),
torch.cat([b._lora_B for b in args[0]], dim),
)
elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
return LoraTorchTensor(
args[0][0]._lora_A,
torch.cat([b._lora_B for b in args[0]], dim),
)
else:
raise NotImplementedError
else:
raise NotImplementedError
def get_base_tensor_name(lora_tensor_name: str) -> str:
base_name = lora_tensor_name.replace("base_model.model.", "")
base_name = base_name.replace(".lora_A.weight", ".weight")
base_name = base_name.replace(".lora_B.weight", ".weight")
return base_name
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
help="model is executed on big endian machine",
)
parser.add_argument(
"--no-lazy", action="store_true",
help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
)
parser.add_argument(
"--verbose", action="store_true",
help="increase output verbosity",
)
parser.add_argument(
"--base", type=Path, required=True,
help="directory containing base model file",
)
parser.add_argument(
"lora_path", type=Path,
help="directory containing LoRA adapter file",
)
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"auto": gguf.LlamaFileType.GUESSED,
}
ftype = ftype_map[args.outtype]
dir_base_model: Path = args.base
dir_lora: Path = args.lora_path
lora_config = dir_lora / "adapter_config.json"
input_model = dir_lora / "adapter_model.safetensors"
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'
if os.path.exists(input_model):
# lazy import load_file only if lora is in safetensors format.
from safetensors.torch import load_file
lora_model = load_file(input_model, device="cpu")
else:
input_model = os.path.join(dir_lora, "adapter_model.bin")
lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
# load base model
logger.info(f"Loading base model: {dir_base_model.name}")
hparams = Model.load_hparams(dir_base_model)
with torch.inference_mode():
try:
model_class = Model.from_model_architecture(hparams["architectures"][0])
except NotImplementedError:
logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1)
class LoraModel(model_class):
model_arch = model_class.model_arch
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
tensor_map: dict[str, PartialLoraTensor] = {}
for name, tensor in lora_model.items():
if self.lazy:
tensor = LazyTorchTensor.from_eager(tensor)
base_name = get_base_tensor_name(name)
is_lora_a = ".lora_A.weight" in name
is_lora_b = ".lora_B.weight" in name
if not is_lora_a and not is_lora_b:
if ".base_layer.weight" in name:
continue
logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
sys.exit(1)
if base_name in tensor_map:
if is_lora_a:
tensor_map[base_name].A = tensor
else:
tensor_map[base_name].B = tensor
else:
if is_lora_a:
tensor_map[base_name] = PartialLoraTensor(A=tensor)
else:
tensor_map[base_name] = PartialLoraTensor(B=tensor)
for name, tensor in tensor_map.items():
assert tensor.A is not None
assert tensor.B is not None
yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
dest = super().modify_tensors(data_torch, name, bid)
for dest_name, dest_data in dest:
assert isinstance(dest_data, LoraTorchTensor)
lora_a, lora_b = dest_data.get_lora_A_B()
yield (dest_name + ".lora_a", lora_a)
yield (dest_name + ".lora_b", lora_b)
model_instance = LoraModel(
dir_base_model,
ftype,
fname_out,
is_big_endian=args.bigendian,
use_temp_file=False,
eager=args.no_lazy,
model_name=None,
)
with open(lora_config, "r") as f:
lparams: dict[str, Any] = json.load(f)
alpha = lparams["lora_alpha"]
model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER)
model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
logger.info("Exporting model...")
model_instance.write()
logger.info(f"Model successfully exported to {model_instance.fname_out}")

View File

@ -9,15 +9,15 @@ Adding a model requires few steps:
After following these steps, you can open PR.
Also, it is important to check that the examples and main ggml backends (CUDA, METAL, CPU) are working with the new architecture, especially:
- [main](../examples/main)
- [imatrix](../examples/imatrix)
- [quantize](../examples/quantize)
- [server](../examples/server)
- [main](/examples/main/)
- [imatrix](/examples/imatrix/)
- [quantize](/examples/quantize/)
- [server](/examples/server/)
### 1. Convert the model to GGUF
This step is done in python with a `convert` script using the [gguf](https://pypi.org/project/gguf/) library.
Depending on the model architecture, you can use either [convert_hf_to_gguf.py](../convert_hf_to_gguf.py) or [examples/convert_legacy_llama.py](../examples/convert_legacy_llama.py) (for `llama/llama2` models in `.pth` format).
Depending on the model architecture, you can use either [convert_hf_to_gguf.py](/convert_hf_to_gguf.py) or [examples/convert_legacy_llama.py](/examples/convert_legacy_llama.py) (for `llama/llama2` models in `.pth` format).
The convert script reads the model configuration, tokenizer, tensor names+data and converts them to GGUF metadata and tensors.
@ -31,7 +31,7 @@ class MyModel(Model):
model_arch = gguf.MODEL_ARCH.GROK
```
2. Define the layout of the GGUF tensors in [constants.py](../gguf-py/gguf/constants.py)
2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py)
Add an enum entry in `MODEL_ARCH`, the model human friendly name in `MODEL_ARCH_NAMES` and the GGUF tensor names in `MODEL_TENSORS`.
@ -54,7 +54,7 @@ Example for `falcon` model:
As a general rule, before adding a new tensor name to GGUF, be sure the equivalent naming does not already exist.
Once you have found the GGUF tensor name equivalent, add it to the [tensor_mapping.py](../gguf-py/gguf/tensor_mapping.py) file.
Once you have found the GGUF tensor name equivalent, add it to the [tensor_mapping.py](/gguf-py/gguf/tensor_mapping.py) file.
If the tensor name is part of a repetitive layer/block, the key word `bid` substitutes it.
@ -100,7 +100,7 @@ Have a look at existing implementation like `build_llama`, `build_dbrx` or `buil
When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.
Note: to debug the inference graph: you can use [llama-eval-callback](../examples/eval-callback).
Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/).
## GGUF specification

View File

@ -1,7 +1,7 @@
# Token generation performance troubleshooting
## Verifying that the model is running on the GPU with CUDA
Make sure you compiled llama with the correct env variables according to [this guide](../README.md#CUDA), so that llama accepts the `-ngl N` (or `--n-gpu-layers N`) flag. When running llama, you may configure `N` to be very large, and llama will offload the maximum possible number of layers to the GPU, even if it's less than the number you configured. For example:
Make sure you compiled llama with the correct env variables according to [this guide](/docs/build.md#cuda), so that llama accepts the `-ngl N` (or `--n-gpu-layers N`) flag. When running llama, you may configure `N` to be very large, and llama will offload the maximum possible number of layers to the GPU, even if it's less than the number you configured. For example:
```shell
./llama-cli -m "path/to/model.gguf" -ngl 200000 -p "Please sir, may I have some "
```

View File

@ -6,7 +6,7 @@ import re
from copy import copy
from enum import Enum
from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse
from pydantic import BaseModel, create_model
@ -53,35 +53,38 @@ class PydanticDataType(Enum):
def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
if isclass(pydantic_type) and issubclass(pydantic_type, str):
origin_type = get_origin(pydantic_type)
origin_type = pydantic_type if origin_type is None else origin_type
if isclass(origin_type) and issubclass(origin_type, str):
return PydanticDataType.STRING.value
elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
elif isclass(origin_type) and issubclass(origin_type, bool):
return PydanticDataType.BOOLEAN.value
elif isclass(pydantic_type) and issubclass(pydantic_type, int):
elif isclass(origin_type) and issubclass(origin_type, int):
return PydanticDataType.INTEGER.value
elif isclass(pydantic_type) and issubclass(pydantic_type, float):
elif isclass(origin_type) and issubclass(origin_type, float):
return PydanticDataType.FLOAT.value
elif isclass(pydantic_type) and issubclass(pydantic_type, Enum):
elif isclass(origin_type) and issubclass(origin_type, Enum):
return PydanticDataType.ENUM.value
elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
return format_model_and_field_name(pydantic_type.__name__)
elif get_origin(pydantic_type) is list:
elif isclass(origin_type) and issubclass(origin_type, BaseModel):
return format_model_and_field_name(origin_type.__name__)
elif origin_type is list:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-list"
elif get_origin(pydantic_type) is set:
elif origin_type is set:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-set"
elif get_origin(pydantic_type) is Union:
elif origin_type is Union:
union_types = get_args(pydantic_type)
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
return f"union-{'-or-'.join(union_rules)}"
elif get_origin(pydantic_type) is Optional:
elif origin_type is Optional:
element_type = get_args(pydantic_type)[0]
return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
elif isclass(pydantic_type):
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
elif get_origin(pydantic_type) is dict:
elif isclass(origin_type):
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}"
elif origin_type is dict:
key_type, value_type = get_args(pydantic_type)
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
else:
@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name):
# Modify this comprehension
members = [
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
for name, param_type in cls.__annotations__.items()
for name, param_type in get_type_hints(cls).items()
if name != "self"
]
@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type(
field_name = format_model_and_field_name(field_name)
gbnf_type = map_pydantic_type_to_gbnf(field_type)
if isclass(field_type) and issubclass(field_type, BaseModel):
origin_type = get_origin(field_type)
origin_type = field_type if origin_type is None else origin_type
if isclass(origin_type) and issubclass(origin_type, BaseModel):
nested_model_name = format_model_and_field_name(field_type.__name__)
nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
rules.extend(nested_model_rules)
gbnf_type, rules = nested_model_name, rules
elif isclass(field_type) and issubclass(field_type, Enum):
elif isclass(origin_type) and issubclass(origin_type, Enum):
enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
rules.append(enum_rule)
gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == list: # Array
elif origin_type is list: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type(
rules.append(array_rule)
gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == set or field_type == set: # Array
elif origin_type is set: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type(
gbnf_type = f"{model_name}-{field_name}-optional"
else:
gbnf_type = f"{model_name}-{field_name}-union"
elif isclass(field_type) and issubclass(field_type, str):
elif isclass(origin_type) and issubclass(origin_type, str):
if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None:
triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False)
markdown_string = field_info.json_schema_extra.get("markdown_code_block", False)
@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type(
gbnf_type = PydanticDataType.STRING.value
elif (
isclass(field_type)
and issubclass(field_type, float)
isclass(origin_type)
and issubclass(origin_type, float)
and field_info
and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None
@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type(
)
elif (
isclass(field_type)
and issubclass(field_type, int)
isclass(origin_type)
and issubclass(origin_type, int)
and field_info
and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None
@ -462,7 +468,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
if not issubclass(model, BaseModel):
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
if hasattr(model, "__annotations__") and model.__annotations__:
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues]
model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
else:
init_signature = inspect.signature(model.__init__)
parameters = init_signature.parameters
@ -470,7 +476,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
name != "self"}
else:
# For Pydantic models, use model_fields and check for ellipsis (required fields)
model_fields = model.__annotations__
model_fields = get_type_hints(model)
model_rule_parts = []
nested_rules = []
@ -706,7 +712,7 @@ def generate_markdown_documentation(
else:
documentation += f" Fields:\n" # noqa: F541
if isclass(model) and issubclass(model, BaseModel):
for name, field_type in model.__annotations__.items():
for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block":
# continue
if get_origin(field_type) == list:
@ -754,14 +760,17 @@ def generate_field_markdown(
field_info = model.model_fields.get(field_name)
field_description = field_info.description if field_info and field_info.description else ""
if get_origin(field_type) == list:
origin_type = get_origin(field_type)
origin_type = field_type if origin_type is None else origin_type
if origin_type == list:
element_type = get_args(field_type)[0]
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
if field_description != "":
field_text += ":\n"
else:
field_text += "\n"
elif get_origin(field_type) == Union:
elif origin_type == Union:
element_types = get_args(field_type)
types = []
for element_type in element_types:
@ -792,9 +801,9 @@ def generate_field_markdown(
example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
field_text += f"{indent} Example: {example_text}\n"
if isclass(field_type) and issubclass(field_type, BaseModel):
if isclass(origin_type) and issubclass(origin_type, BaseModel):
field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items():
for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_markdown(name, type_, field_type, depth + 2)
return field_text
@ -855,7 +864,7 @@ def generate_text_documentation(
if isclass(model) and issubclass(model, BaseModel):
documentation_fields = ""
for name, field_type in model.__annotations__.items():
for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block":
# continue
if get_origin(field_type) == list:
@ -948,7 +957,7 @@ def generate_field_text(
if isclass(field_type) and issubclass(field_type, BaseModel):
field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items():
for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_text(name, type_, field_type, depth + 2)
return field_text

View File

@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
data = response.json()
assert data.get("error") is None, data
print(data["content"])
return data["content"]

View File

@ -15,69 +15,281 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggerganov/llama.cpp/issues/4216).
**Command line options:**
## Usage
- `-v`, `--verbose`: Enable verbose server output. When using the `/completion` endpoint, this includes the tokenized prompt, the full request and the full response.
- `-t N`, `--threads N`: Set the number of threads to use by CPU layers during generation. Not used by model layers that are offloaded to GPU. This option has no effect when using the maximum number of GPU layers. Default: `std::thread::hardware_concurrency()` (number of CPU cores).
- `-tb N, --threads-batch N`: Set the number of threads to use by CPU layers during batch and prompt processing (>= 32 tokens). This option has no effect if a GPU is available. Default: `--threads`.
- `--threads-http N`: Number of threads in the http server pool to process requests. Default: `max(std::thread::hardware_concurrency() - 1, --parallel N + 2)`
- `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`).
- `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file. Default: unused
- `-hfr REPO, --hf-repo REPO`: Hugging Face model repository. Default: unused
- `-hff FILE, --hf-file FILE`: Hugging Face model file. Default: unused
- `-a ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
- `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is `512`, but LLaMA models were built with a context of `2048`, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of `4096`.
- `-ngl N`, `--n-gpu-layers N`: When compiled with GPU support, this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
- `-mg i, --main-gpu i`: When using multiple GPUs, this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default, GPU `0` is used.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs, this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default, the data is split in proportion to VRAM, but this may not be optimal for performance.
- `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `2048`
- `-ub N`, `--ubatch-size N`: Physical maximum batch size. Default: `512`
- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped.
- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed.
- `--numa STRATEGY`: Attempt one of the below optimization strategies that may help on some NUMA systems
- `--numa distribute`: Spread execution evenly over all nodes
- `--numa isolate`: Only spawn threads on CPUs on the node that execution started on
- `--numa numactl`: Use the CPU map provided by numactl. If run without this previously, it is recommended to drop the system page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/1437
- `--numa`: Attempt optimizations that may help on some NUMA systems.
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`
- `--port`: Set the port to listen. Default: `8080`
- `--path`: Path from which to serve static files. Default: disabled
- `--api-key`: Set an api key for request authorization. By default, the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
- `--api-key-file`: Path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`s.
- `--embeddings`: Enable embedding vector output and the OAI compatible endpoint /v1/embeddings. Physical batch size (`--ubatch-size`) must be carefully defined. Default: disabled
- `-np N`, `--parallel N`: Set the number of slots for process requests. Default: `1`. Values > 1 will allow for higher throughput with multiple parallel requests but the results will **not** be deterministic due to differences in rounding error.
- `-cb`, `--cont-batching`: Enable continuous batching (a.k.a dynamic batching). Default: disabled
- `-spf FNAME`, `--system-prompt-file FNAME` Set a file to load a system prompt (initial prompt of all slots). This is useful for chat applications. [See more](#change-system-prompt-on-runtime)
- `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA.
- `--grp-attn-n`: Set the group attention factor to extend context size through self-extend. Used together with group attention width `--grp-attn-w`. Default: `1`, which is disabled.
- `--grp-attn-w`: Set the group attention width to extend context size through self-extend. Used together with group attention factor `--grp-attn-n`. Default: `512`
- `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1`
- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included.
- `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled
- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled.
- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
- `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled
- `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json`
- `--rope-scaling` : RoPE scaling method. Defaults to linear unless otherwise specified by the model. Options are `none`, `linear`, `yarn`
- `--rope-freq-base N` : RoPE frequency base (default: loaded from model)
- `--rope-freq-scale N`: RoPE frequency scaling factor, expands context by a factor of 1/N (e.g. 0.25)
- `--yarn-ext-factor N` : YaRN: extrapolation mix factor (Default: 1.0, 0.0 = full interpolation)
- `--yarn-attn-factor N` : YaRN: scale sqrt(t) or attention magnitude (default: 1.0)
- `--yarn-beta-slow N`: YaRN: High correction dim or alpha (default: 1.0)
- `--yarn-beta-fast N`: YaRN: low correction dim or beta (default: 32.0)
- `--pooling` : Pooling type for embeddings, use model default if unspecified. Options are `none`, `mean`, `cls`
- `-dt N`, `--defrag-thold N`: KV cache defragmentation threshold (default: -1.0, < 0 = disabled)
- `-fa`, `--flash-attn` : enable flash attention (default: disabled).
- `-ctk TYPE`, `--cache-type-k TYPE` : KV cache data type for K (default: `f16`, options `f32`, `f16`, `q8_0`, `q4_0`, `q4_1`, `iq4_nl`, `q5_0`, or `q5_1`)
- `-ctv TYPE`, `--cache-type-v TYPE` : KV cache type for V (default `f16`, see `-ctk` for options)
- `--spm-infill` : Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
```
usage: ./llama-server [options]
general:
-h, --help, --usage print usage and exit
--version show version and build info
-v, --verbose print verbose information
--verbosity N set specific verbosity level (default: 0)
--verbose-prompt print a verbose prompt before generation (default: false)
--no-display-prompt don't print prompt at generation (default: false)
-co, --color colorise output to distinguish prompt and user input from generations (default: false)
-s, --seed SEED RNG seed (default: -1, use random seed for < 0)
-t, --threads N number of threads to use during generation (default: 8)
-tb, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)
-td, --threads-draft N number of threads to use during generation (default: same as --threads)
-tbd, --threads-batch-draft N number of threads to use during batch and prompt processing (default: same as --threads-draft)
--draft N number of tokens to draft for speculative decoding (default: 5)
-ps, --p-split N speculative decoding split probability (default: 0.1)
-lcs, --lookup-cache-static FNAME
path to static lookup cache to use for lookup decoding (not updated by generation)
-lcd, --lookup-cache-dynamic FNAME
path to dynamic lookup cache to use for lookup decoding (updated by generation)
-c, --ctx-size N size of the prompt context (default: 0, 0 = loaded from model)
-n, --predict N number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled)
-b, --batch-size N logical maximum batch size (default: 2048)
-ub, --ubatch-size N physical maximum batch size (default: 512)
--keep N number of tokens to keep from the initial prompt (default: 0, -1 = all)
--chunks N max number of chunks to process (default: -1, -1 = all)
-fa, --flash-attn enable Flash Attention (default: disabled)
-p, --prompt PROMPT prompt to start generation with
in conversation mode, this will be used as system prompt
(default: '')
-f, --file FNAME a file containing the prompt (default: none)
--in-file FNAME an input file (repeat to specify multiple files)
-bf, --binary-file FNAME binary file containing the prompt (default: none)
-e, --escape process escapes sequences (\n, \r, \t, \', \", \\) (default: true)
--no-escape do not process escape sequences
-ptc, --print-token-count N print token count every N tokens (default: -1)
--prompt-cache FNAME file to cache prompt state for faster startup (default: none)
--prompt-cache-all if specified, saves user input and generations to cache as well
not supported with --interactive or other interactive options
--prompt-cache-ro if specified, uses the prompt cache but does not update it
-r, --reverse-prompt PROMPT halt generation at PROMPT, return control in interactive mode
can be specified more than once for multiple prompts
-sp, --special special tokens output enabled (default: false)
-cnv, --conversation run in conversation mode, does not print special tokens and suffix/prefix
if suffix/prefix are not specified, default chat template will be used
(default: false)
-i, --interactive run in interactive mode (default: false)
-if, --interactive-first run in interactive mode and wait for input right away (default: false)
-mli, --multiline-input allows you to write or paste multiple lines without ending each in '\'
--in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string
--in-prefix STRING string to prefix user inputs with (default: empty)
--in-suffix STRING string to suffix after user inputs with (default: empty)
--spm-infill use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled)
sampling:
--samplers SAMPLERS samplers that will be used for generation in the order, separated by ';'
(default: top_k;tfs_z;typical_p;top_p;min_p;temperature)
--sampling-seq SEQUENCE simplified sequence for samplers that will be used (default: kfypmt)
--ignore-eos ignore end of stream token and continue generating (implies --logit-bias EOS-inf)
--penalize-nl penalize newline tokens (default: false)
--temp N temperature (default: 0.8)
--top-k N top-k sampling (default: 40, 0 = disabled)
--top-p N top-p sampling (default: 0.9, 1.0 = disabled)
--min-p N min-p sampling (default: 0.1, 0.0 = disabled)
--tfs N tail free sampling, parameter z (default: 1.0, 1.0 = disabled)
--typical N locally typical sampling, parameter p (default: 1.0, 1.0 = disabled)
--repeat-last-n N last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size)
--repeat-penalty N penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled)
--presence-penalty N repeat alpha presence penalty (default: 0.0, 0.0 = disabled)
--frequency-penalty N repeat alpha frequency penalty (default: 0.0, 0.0 = disabled)
--dynatemp-range N dynamic temperature range (default: 0.0, 0.0 = disabled)
--dynatemp-exp N dynamic temperature exponent (default: 1.0)
--mirostat N use Mirostat sampling.
Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
--mirostat-lr N Mirostat learning rate, parameter eta (default: 0.1)
--mirostat-ent N Mirostat target entropy, parameter tau (default: 5.0)
-l TOKEN_ID(+/-)BIAS modifies the likelihood of token appearing in the completion,
i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',
or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'
--cfg-negative-prompt PROMPT
negative prompt to use for guidance (default: '')
--cfg-negative-prompt-file FNAME
negative prompt file to use for guidance
--cfg-scale N strength of guidance (default: 1.0, 1.0 = disable)
--chat-template JINJA_TEMPLATE
set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted:
https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
grammar:
--grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '')
--grammar-file FNAME file to read grammar from
-j, --json-schema SCHEMA JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead
embedding:
--pooling {none,mean,cls,last}
pooling type for embeddings, use model default if unspecified
--attention {causal,non-causal}
attention type for embeddings, use model default if unspecified
context hacking:
--rope-scaling {none,linear,yarn}
RoPE frequency scaling method, defaults to linear unless specified by the model
--rope-scale N RoPE context scaling factor, expands context by a factor of N
--rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)
--rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N
--yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)
--yarn-ext-factor N YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
--yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)
--yarn-beta-slow N YaRN: high correction dim or alpha (default: 1.0)
--yarn-beta-fast N YaRN: low correction dim or beta (default: 32.0)
-gan, --grp-attn-n N group-attention factor (default: 1)
-gaw, --grp-attn-w N group-attention width (default: 512.0)
-dkvc, --dump-kv-cache verbose print of the KV cache
-nkvo, --no-kv-offload disable KV offload
-ctk, --cache-type-k TYPE KV cache data type for K (default: f16)
-ctv, --cache-type-v TYPE KV cache data type for V (default: f16)
perplexity:
--all-logits return logits for all tokens in the batch (default: false)
--hellaswag compute HellaSwag score over random tasks from datafile supplied with -f
--hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: 400)
--winogrande compute Winogrande score over random tasks from datafile supplied with -f
--winogrande-tasks N number of tasks to use when computing the Winogrande score (default: 0)
--multiple-choice compute multiple choice score over random tasks from datafile supplied with -f
--multiple-choice-tasks N
number of tasks to use when computing the multiple choice score (default: 0)
--kl-divergence computes KL-divergence to logits provided via --kl-divergence-base
--ppl-stride N stride for perplexity calculation (default: 0)
--ppl-output-type {0,1} output type for perplexity calculation (default: 0)
parallel:
-dt, --defrag-thold N KV cache defragmentation threshold (default: -1.0, < 0 - disabled)
-np, --parallel N number of parallel sequences to decode (default: 1)
-ns, --sequences N number of sequences to decode (default: 1)
-cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)
multi-modality:
--mmproj FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md
--image FILE path to an image file. use with multimodal models. Specify multiple times for batching
backend:
--rpc SERVERS comma separated list of RPC servers
--mlock force system to keep model in RAM rather than swapping or compressing
--no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)
--numa TYPE attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggerganov/llama.cpp/issues/1437
model:
--check-tensors check model tensor data for invalid values (default: false)
--override-kv KEY=TYPE:VALUE
advanced option to override model metadata by key. may be specified multiple times.
types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false
--lora FNAME apply LoRA adapter (implies --no-mmap)
--lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)
--lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter
--control-vector FNAME add a control vector
note: this argument can be repeated to add multiple control vectors
--control-vector-scaled FNAME SCALE
add a control vector with user defined scaling SCALE
note: this argument can be repeated to add multiple scaled control vectors
--control-vector-layer-range START END
layer range to apply the control vector(s) to, start and end inclusive
-m, --model FNAME model path (default: models/$filename with filename from --hf-file
or --model-url if set, otherwise models/7B/ggml-model-f16.gguf)
-md, --model-draft FNAME draft model for speculative decoding (default: unused)
-mu, --model-url MODEL_URL model download url (default: unused)
-hfr, --hf-repo REPO Hugging Face model repository (default: unused)
-hff, --hf-file FILE Hugging Face model file (default: unused)
-hft, --hf-token TOKEN Hugging Face access token (default: value from HF_TOKEN environment variable)
retrieval:
--context-file FNAME file to load context from (repeat to specify multiple files)
--chunk-size N minimum length of embedded text chunks (default: 64)
--chunk-separator STRING
separator between chunks (default: '
')
passkey:
--junk N number of times to repeat the junk text (default: 250)
--pos N position of the passkey in the junk text (default: -1)
imatrix:
-o, --output FNAME output file (default: 'imatrix.dat')
--output-frequency N output the imatrix every N iterations (default: 10)
--save-frequency N save an imatrix copy every N iterations (default: 0)
--process-output collect data for the output tensor (default: false)
--no-ppl do not compute perplexity (default: true)
--chunk N start processing the input from chunk N (default: 0)
bench:
-pps is the prompt shared across parallel sequences (default: false)
-npp n0,n1,... number of prompt tokens
-ntg n0,n1,... number of text generation tokens
-npl n0,n1,... number of parallel prompts
embedding:
--embd-normalize normalisation for embendings (default: 2) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
--embd-output-format empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
--embd-separator separator of embendings (default \n) for example "<#sep#>"
server:
--host HOST ip address to listen (default: 127.0.0.1)
--port PORT port to listen (default: 8080)
--path PATH path to serve static files from (default: )
--embedding(s) enable embedding endpoint (default: disabled)
--api-key KEY API key to use for authentication (default: none)
--api-key-file FNAME path to file containing API keys (default: none)
--ssl-key-file FNAME path to file a PEM-encoded SSL private key
--ssl-cert-file FNAME path to file a PEM-encoded SSL certificate
--timeout N server read/write timeout in seconds (default: 600)
--threads-http N number of threads used to process HTTP requests (default: -1)
--system-prompt-file FNAME
set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications
--log-format {text,json}
log output format: json or text (default: json)
--metrics enable prometheus compatible metrics endpoint (default: disabled)
--no-slots disables slots monitoring endpoint (default: enabled)
--slot-save-path PATH path to save slot kv cache (default: disabled)
--chat-template JINJA_TEMPLATE
set custom jinja chat template (default: template taken from model's metadata)
only commonly used templates are accepted:
https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
-sps, --slot-prompt-similarity SIMILARITY
how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
logging:
--simple-io use basic IO for better compatibility in subprocesses and limited consoles
-ld, --logdir LOGDIR path under which to save YAML logs (no logging if unset)
--log-test Run simple logging test
--log-disable Disable trace logs
--log-enable Enable trace logs
--log-file FNAME Specify a log filename (without extension)
--log-new Create a separate new log file on start. Each log file will have unique name: "<name>.<ID>.log"
--log-append Don't truncate the old log file.
cvector:
-o, --output FNAME output file (default: 'control_vector.gguf')
--positive-file FNAME positive prompts file, one prompt per line (default: 'examples/cvector-generator/positive.txt')
--negative-file FNAME negative prompts file, one prompt per line (default: 'examples/cvector-generator/negative.txt')
--pca-batch N batch size used for PCA. Larger batch runs faster, but uses more memory (default: 100)
--pca-iter N number of iterations used for PCA (default: 1000)
--method {pca,mean} dimensionality reduction method to be used (default: pca)
```
**If compiled with `LLAMA_SERVER_SSL=ON`**
- `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key
- `--ssl-cert-file FNAME`: path to file a PEM-encoded SSL certificate
## Build

View File

@ -14,7 +14,9 @@
#include "ggml-aarch64.h"
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#endif
#define UNUSED GGML_UNUSED

View File

@ -1876,7 +1876,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
&& src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;

View File

@ -291,29 +291,6 @@ static void sqr_f32(const float * x, float * dst, const int k,
dst[i] = x[i] * x[i];
}
static void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02,
const sycl::nd_item<3> &item_ct1) {
int nidx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (nidx >= ne0) {
return;
}
// operation
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
if (item_ct1.get_group(0) < ne02) { // src0
int offset_src =
nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
nidx + item_ct1.get_group(1) * ne0 +
(item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);
dst[offset_dst] = y[offset_src];
}
}
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
@ -1347,20 +1324,6 @@ static void sqr_f32_sycl(const float *x, float *dst, const int k,
});
}
static void concat_f32_sycl(const float *x, const float *y, float *dst,
const int ne0, int ne1, int ne2, int ne02,
queue_ptr stream) {
int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
sycl::range<3> gridDim(ne2, ne1, num_blocks);
stream->parallel_for(
sycl::nd_range<3>(gridDim *
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
concat_f32(x, y, dst, ne0, ne02, item_ct1);
});
}
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
@ -2429,28 +2392,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor
(void) src1_dd;
}
inline void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
#pragma message("TODO: generalize concat kernel for dim != 2")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563")
int dim = dst->op_params[0];
GGML_ASSERT(dim == 2);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_sycl(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
}
(void) src1;
(void) dst;
}
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
@ -3359,12 +3300,6 @@ static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_ten
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_concat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_concat);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
@ -4101,7 +4036,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
func = ggml_sycl_group_norm;
break;
case GGML_OP_CONCAT:
func = ggml_sycl_concat;
func = ggml_sycl_op_concat;
break;
case GGML_OP_UPSCALE:
func = ggml_sycl_upscale;

View File

@ -13,6 +13,7 @@
#ifndef GGML_SYCL_BACKEND_HPP
#define GGML_SYCL_BACKEND_HPP
#include "concat.hpp"
#include "common.hpp"
#include "convert.hpp"
#include "dequantize.hpp"

View File

@ -0,0 +1,195 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#include "concat.hpp"
#include "common.hpp"
static void concat_f32_dim0(const float *x, const float *y, float *dst,
const int ne0, const int ne00,
const sycl::nd_item<3> &item_ct1) {
int nidx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (nidx >= ne0) {
return;
}
// operation
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
if (nidx < ne00) { // src0
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
item_ct1.get_group(0) * ne00 * item_ct1.get_group_range(1);
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
nidx - ne00 + item_ct1.get_group(1) * (ne0 - ne00) +
item_ct1.get_group(0) * (ne0 - ne00) * item_ct1.get_group_range(1);
dst[offset_dst] = y[offset_src];
}
}
static void concat_f32_dim1(const float *x, const float *y, float *dst,
const int ne0, const int ne01,
const sycl::nd_item<3> &item_ct1) {
int nidx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (nidx >= ne0) {
return;
}
// operation
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
if (item_ct1.get_group(1) < ne01) { // src0
int offset_src =
nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01;
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
nidx + (item_ct1.get_group(1) - ne01) * ne0 +
item_ct1.get_group(0) * ne0 * (item_ct1.get_group_range(1) - ne01);
dst[offset_dst] = y[offset_src];
}
}
static void concat_f32_dim2(const float *x, const float *y, float *dst,
const int ne0, const int ne02,
const sycl::nd_item<3> &item_ct1) {
int nidx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (nidx >= ne0) {
return;
}
// operation
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
if (item_ct1.get_group(0) < ne02) { // src0
int offset_src = nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
nidx + item_ct1.get_group(1) * ne0 +
(item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);
dst[offset_dst] = y[offset_src];
}
}
static void concat_f32_sycl(const float *x, const float *y, float *dst,
int ne00, int ne01, int ne02, int ne0, int ne1,
int ne2, int dim, queue_ptr stream) {
int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
sycl::range<3> gridDim(ne2, ne1, num_blocks);
switch (dim) {
case 0:
stream->parallel_for(
sycl::nd_range<3>(gridDim *
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
});
break;
case 1:
stream->parallel_for(
sycl::nd_range<3>(gridDim *
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
});
break;
default:
stream->parallel_for(
sycl::nd_range<3>(gridDim *
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
});
break;
}
}
// non-contiguous kernel (slow)
static void concat_f32_sycl_non_cont(
queue_ptr stream, const char *src0, const char *src1, char *dst,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,
uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/,
int64_t /*ne11*/, int64_t /*ne12*/, int64_t /*ne13*/, uint64_t nb10,
uint64_t nb11, uint64_t nb12, uint64_t nb13, int64_t ne0, int64_t ne1,
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
uint64_t nb3, int32_t dim) {
sycl::range<3> gridDim(ne3, ne2, ne1);
stream->parallel_for(
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
int64_t i3 = item_ct1.get_group(0);
int64_t i2 = item_ct1.get_group(1);
int64_t i1 = item_ct1.get_group(2);
int64_t o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
const float *x;
for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
i0 += item_ct1.get_local_range(2)) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
(i0)*nb00);
} else {
x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
(i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
}
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
*y = *x;
}
});
}
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst) {
queue_ptr stream = ctx.stream();
const int32_t dim = ((int32_t *)dst->op_params)[0];
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
const float *src0_d = (const float *)src0->data;
const float *src1_d = (const float *)src1->data;
float *dst_d = (float *)dst->data;
if (dim != 3) {
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_sycl(
src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4),
dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1],
src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
} else {
const size_t size0 = ggml_nbytes(src0);
const size_t size1 = ggml_nbytes(src1);
SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait()));
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait()));
}
} else
concat_f32_sycl_non_cont(
stream, (const char *)src0->data, (const char *)src1->data,
(char *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1],
src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2],
dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
}

View File

@ -0,0 +1,21 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_CONCAT_HPP
#define GGML_SYCL_CONCAT_HPP
#include "common.hpp"
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst);
#endif // GGML_SYCL_CONCAT_HPP

View File

@ -6561,7 +6561,7 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
vk_buffer buffer_gpu = extra->buffer_gpu.lock();
ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
}
std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
@ -6645,7 +6645,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
for (int i3 = 0; i3 < src0->ne[3]; i3++) {
for (int i2 = 0; i2 < src0->ne[2]; i2++) {
const int idx = i3*src0->ne[2] + i2;
ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
}
}
@ -6658,7 +6658,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
if (offset + src0_size >= buffer_gpu->size) {
src0_size = buffer_gpu->size - offset;
}
ggml_vk_buffer_read(ctx, buffer_gpu, offset, src0_clone->data, src0_size);
ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
}
} else {
@ -6687,7 +6687,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
for (int i3 = 0; i3 < src1->ne[3]; i3++) {
for (int i2 = 0; i2 < src1->ne[2]; i2++) {
const int idx = i3*src1->ne[2] + i2;
ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
}
}
@ -6700,7 +6700,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
if (offset + src1_size >= buffer_gpu->size) {
src1_size = buffer_gpu->size - offset;
}
ggml_vk_buffer_read(ctx, buffer_gpu, offset, src1_clone->data, src1_size);
ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
}
} else {
@ -6745,7 +6745,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
for (int i3 = 0; i3 < src2->ne[3]; i3++) {
for (int i2 = 0; i2 < src2->ne[2]; i2++) {
const int idx = i3*src2->ne[2] + i2;
ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
}
}
@ -6758,7 +6758,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
if (offset + src2_size >= buffer_gpu->size) {
src2_size = buffer_gpu->size - offset;
}
ggml_vk_buffer_read(ctx, buffer_gpu, offset, src2_clone->data, src2_size);
ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
}
} else {
@ -6922,7 +6922,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor *
tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs);
}
ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
}
float first_error_result = -1.0f;

View File

@ -19478,7 +19478,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
fprintf(fp, "digraph G {\n");
fprintf(fp, " newrank = true;\n");
fprintf(fp, " rankdir = LR;\n");
fprintf(fp, " rankdir = TB;\n");
for (int i = 0; i < gb->n_nodes; i++) {
struct ggml_tensor * node = gb->nodes[i];
@ -19540,7 +19540,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
}
fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
if (ggml_nelements(node) < 5) {
if (ggml_nelements(node) < 5 && node->data != NULL) {
fprintf(fp, " | (");
for (int j = 0; j < ggml_nelements(node); j++) {
if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {

View File

@ -270,10 +270,10 @@ void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmu
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "2"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
}));
}
}

View File

@ -19,6 +19,7 @@ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
class Keys:
class General:
TYPE = "general.type"
ARCHITECTURE = "general.architecture"
QUANTIZATION_VERSION = "general.quantization_version"
ALIGNMENT = "general.alignment"
@ -120,11 +121,20 @@ class Keys:
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
EOT_ID = "tokenizer.ggml.eot_token_id"
class Adapter:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
#
# recommended mapping of model tensor names for storage in gguf
#
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
class MODEL_ARCH(IntEnum):
LLAMA = auto()
FALCON = auto()

View File

@ -424,6 +424,9 @@ class GGUFWriter:
fout.close()
self.fout = None
def add_type(self, type_name: str) -> None:
self.add_string(Keys.General.TYPE, type_name)
def add_architecture(self) -> None:
self.add_string(Keys.General.ARCHITECTURE, self.arch)

View File

@ -43,7 +43,7 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
osize *= dim
out = np.empty(shape=osize, dtype=otype)
# compute over groups of 16 rows (arbitrary, but seems good for performance)
n_groups = rows.shape[0] // 16
n_groups = (rows.shape[0] // 16) or 1
np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
return out.reshape(oshape)

View File

@ -411,6 +411,9 @@ extern "C" {
const char * content;
} llama_chat_message;
// lora adapter
struct llama_lora_adapter;
// Helpers for getting default parameters
LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void);
@ -510,18 +513,28 @@ extern "C" {
const char * fname_out,
const llama_model_quantize_params * params);
// Apply a LoRA adapter to a loaded model
// path_base_model is the path to a higher quality model to use as a base for
// the layers modified by the adapter. Can be NULL to use the current loaded model.
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
// will be applied on top of the previous one
// Returns 0 on success
LLAMA_API int32_t llama_model_apply_lora_from_file(
const struct llama_model * model,
const char * path_lora,
float scale,
const char * path_base_model,
int32_t n_threads);
// Load a LoRA adapter from file
// The loaded adapter will be associated to the given model, and will be free when the model is deleted
LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
struct llama_model * model,
const char * path_lora);
// Add a loaded LoRA adapter to given context
// This will not modify model's weight
LLAMA_API int32_t llama_lora_adapter_set(
struct llama_context * ctx,
struct llama_lora_adapter * adapter,
float scale);
// Remove a LoRA adapter from given context
// Return -1 if the adapter is not present in the context
LLAMA_API int32_t llama_lora_adapter_remove(
struct llama_context * ctx,
struct llama_lora_adapter * adapter);
// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
// Apply a loaded control vector to a llama_context, or if data is NULL, clear
// the currently loaded vector.

View File

@ -9,3 +9,4 @@
-r ./requirements/requirements-convert_hf_to_gguf.txt
-r ./requirements/requirements-convert_hf_to_gguf_update.txt
-r ./requirements/requirements-convert_llama_ggml_to_gguf.txt
-r ./requirements/requirements-convert_lora_to_gguf.txt

View File

@ -0,0 +1,2 @@
-r ./requirements-convert_hf_to_gguf.txt
--extra-index-url https://download.pytorch.org/whl/cpu

View File

@ -1,2 +1,3 @@
docstring_parser~=0.15
pydantic~=2.6.3
requests

File diff suppressed because it is too large Load Diff