mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
c2101a2e90
* mamba : begin working on support for Mamba SSM
* mamba : begin figuring out how to (ab)use the kv cache for Mamba
* mamba : recurrent inference almost works, but incoherent
* mamba : recurrent inference WORKS!!!
* convert : optionally use d_conv and d_state from config.json for Mamba
* mamba : refactor recurrent conv, resulting in 20% perf increase
  It's still slower than I'd like, but I did not really optimize `ggml_exp` yet.
  I also refactored `ggml_exp` to work with tensors with more than 2 dimensions.
* ggml : parallelize ggml_exp
  This results in 8% faster token generation for Mamba-130M.
* mamba : simplify the conv step with a self-overlapping view
  Turns out the conv_state can be made smaller by one column. Note that this breaks existing GGUFs of Mamba, because the key_value_length field is tied to the conv_state size.
  Convolution with a self-overlapping view is cool! And it's much simpler than what I initially thought would be necessary to make the convolution step work with more than 1 token at a time.
  Next step is to make the SSM step work on batches of tokens too, and thus I need to figure out a way to make a parallel selective scan which will keep the ssm_state small and won't make it bigger by a factor of (n_layer * batch_size).
* llama : fix Mamba KV self size wrongly displaying as f16 instead of f32
  Relatedly, I also tried to see if types other than f32 worked for the states, but they don't, because of the operators used. It's probably better to keep lots of precision there anyway, since the states are small.
* mamba : fix self-overlapping view depth stride
* mamba : handle batches of more than 1 token
  This means running Mamba no longer crashes when using the default settings! It probably also makes prompt processing slightly faster. Both batched and non-batched processing yield the same output.
  Previously, the state was not cleared when starting a sequence. Next step is to make the KV cache API work as expected for Mamba models.
* ggml : add ggml_ssm_scan to help with parallel selective scan
  If the selective scan were implemented without a custom operator, there would be waaay too many nodes in the graph. For example, for Mamba-130M, with a batch size of 512 (the default), a naive selective scan could add at least 24*512=12288 nodes, which is more than LLAMA_MAX_NODES (8192), and that's only for the smallest Mamba model. So it's much cleaner with a custom operator. Not sure about the name, though.
* ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation
  This will help with performance on CPU if ggml_vec_mul_f32 and ggml_vec_add_f32 are ever optimized with SIMD.
* mamba : very basic quantization support
  Mostly works, but there is currently no difference between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same). Most of the SSM-specific weights can be kept in f32 without affecting the size that much, since they are relatively small. (The linear projection weights are responsible for most of Mamba's size.)
  Too much quantization seems to make the state degrade quite fast, and the model begins to output gibberish. It seems to affect bigger models to a lesser extent than small models, but I'm not sure by how much.
  Experimentation will be needed to figure out which weights are more important for the _M (and _L?) variants of k-quants for Mamba.
* convert : fix wrong name for layer norm weight of official Mamba models
  I was using Q-bert/Mamba-* models before, which have a slightly different naming scheme for the weights. (They start with "model.layers" instead of "backbone.layers".)
* mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator
  This increases performance on CPU by around 30% for prompt processing, and by around 20% for text generation.
  However, it also makes the ggml_exp and ggml_soft_plus operators unused. Whether or not they should be kept will be decided later.
* convert : for Mamba, also consider the "MambaLMHeadModel" arch name
  It's the name of the class of the official implementation, though they don't use it (yet) in the "architectures" field of config.json.
* mamba : fix vocab size problems with official models
  The perplexity was waaaay too high for models with a non-round vocab size. Not sure why, but it needed to be fixed in the metadata.
  Note that this breaks existing GGUF-converted Mamba models, but **only if** the vocab size was not already rounded.
* ggml : remove ggml_exp and ggml_soft_plus
  They did not exist anyway outside of this branch, and since ggml_ssm_scan fused operations together, they are unused. It's always possible to bring them back if needed.
* mamba : remove some useless comments
  No code change.
* convert : fix flake8 linter errors
* mamba : apply suggestions from code review
* mamba : remove unnecessary branch for row-wise ssm_state and C multiplication
  It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof.
* ggml : in ggml_ssm_scan, use more appropriate asserts
* ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32
* mamba : multiple sequences, but one at a time
  This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though).
  The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used. (For this specific case, --parallel 0 helps.)
  Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step.
* mamba : support llama_kv_cache_seq_cp
  This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples.
  Each KV cell is dedicated to the sequence ID corresponding to its own index.
* mamba : use a state mask
  It's cleaner than the previous heuristic of checking for the pos of the first token in the batch.
  inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)).
* llama : replace the usage of n_ctx with kv_self.size in many places
* mamba : use n_tokens directly instead of n_tok
* mamba : in comments, properly refer to KV cells instead of slots
* mamba : reduce memory usage of ggml_ssm_scan
  From 290.37 MiB to 140.68 MiB of CPU compute buffer size with Mamba 3B with a batch size of 512.
  The result tensor of ggml_ssm_scan was previously a big part of the CPU compute buffer size. To make it smaller, it does not contain the intermediate ssm states anymore. Both y and the last ssm state are combined in the result tensor, because it seems only a single tensor can be returned by an operator with the way the graph is built.
* mamba : simultaneous sequence processing
  A batch can now contain tokens from multiple sequences.
  This is necessary for at least the parallel example, the server example, and the HellaSwag test in the perplexity example.
  However, for this to be useful, uses of llama_kv_cache_seq_rm/cp will need to be changed to work on whole sequences.
* ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba
  This operator makes it possible to use and update the correct states for each token of the batch in the same way as ggml_ssm_scan. Other solutions which use existing operators would need loops which would add too many nodes to the graph (at least the ones I thought of).
  Using this operator further reduces the size of the CPU compute buffer from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512. At least on CPU, it's also a bit faster than before.
  Note that "ggml_ssm_conv" is probably not the most appropriate name, and it could be changed if a better one is found.
* llama : add inp_s_seq as a new input tensor
  The most convenient implementation to select the correct state (for Mamba) for each token is to directly get the correct index from a tensor. This is why inp_s_seq is storing int32_t and not floats.
  The other, less convenient way to select the correct state would be to have inp_KQ_mask contain 1.0f for each state used by a token and 0.0f otherwise. This complicates quickly fetching the first used state of a token, and is also less efficient because a whole row of the mask would always need to be read for each token.
  Using indexes makes it easy to stop searching when there are no more sequences for a token, and the first sequence assigned is always very quickly available (it's the first element of each row).
* mamba : support llama_kv_cache_seq_cp copy chains
* mamba : support shifting and dividing the kv cache pos
* mamba : make the server and parallel examples work with whole sequences
  A seq_id is dedicated to the system prompt in both cases.
* llama : make llama_kv_cache_seq_rm return whether it succeeded or not
* mamba : dedicate an input tensor for state copy indices
  This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers.
* mamba : adapt perplexity, batched, and batched-bench examples
* perplexity : limit the max number of sequences
  This adapts to what the loaded model can provide.
* llama : add llama_n_max_seq to get the upper limit for seq_ids
  Used by the perplexity example.
* batched : pass n_parallel to the model's context params
  This should have been there already, but it wasn't.
* batched-bench : reserve sequences to support Mamba
* batched-bench : fix tokens being put in wrong sequences
  Generation quality isn't what's measured in there anyway, but at least using the correct sequences avoids using non-consecutive token positions.
* mamba : stop abusing attention metadata
  This breaks existing converted-to-GGUF Mamba models, but will allow supporting mixed architectures like MambaFormer without needing to break Mamba models.
  This will also allow changing the size of Mamba's states without having to reconvert models in the future. (e.g. using something other than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again)
* gguf-py : add new KV metadata key-value pairs for Mamba
* llama : add new metadata key-value pairs for Mamba
* llama : guard against divisions by zero when n_head is 0
* mamba : rename "unlimited" KV cache property to "recurrent"
* mamba : more correctly update the "used" field of the KV cache
* ggml : in ggml_ssm_scan, use a threshold for soft_plus
  This is how the official Mamba implementation does it, and it's also what torch.nn.Softplus does.
* convert : for Mamba, fall back to internal NeoX tokenizer
  The resulting models are exactly the same as if the tokenizer.json and tokenizer_config.json of GPT-NeoX were there.
* mamba : support state saving and restoring
* ggml : implicitly pass src tensors through dst for Mamba-related ops
* mamba : clarify some comments
* server : fix cache_tokens not getting correctly resized
  Otherwise, when the "we have to evaluate at least 1 token" special case was triggered, an extra token was kept in cache_tokens even if it was removed from the KV cache.
  For Mamba, this caused useless prompt reprocessing when the previous request triggered the above case.
* convert-hf : support new metadata keys for Mamba
  For the models available at https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406
* mamba : rename metadata to be more similar to transformers library
  This breaks existing converted-to-GGUF models, but the metadata names are more "standard".
* mamba : support mamba-*-hf models
  These models share their token_embd.weight with their output.weight.
* mamba : add missing spaces
  This is purely a formatting change.
* convert-hf : omit output.weight when identical to token_embd.weight
  Only for Mamba for now, but it might be relevant for other models eventually. Most Mamba models actually share these two tensors, albeit implicitly.
* readme : add Mamba to supported models, and add recent API changes
* mamba : move state_seq and state_mask views outside layer loop
  A few tensors were also missing `struct` in front of `ggml_tensor`.
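The conv step described in the commit message above can be pictured with a small pure-Python sketch for a single channel. It assumes a depthwise causal convolution with `d_conv` taps and a `conv_state` holding only the previous `d_conv - 1` inputs (the state that was "made smaller by one column"). The function name `conv_step` is hypothetical, and list slices stand in for the self-overlapping strided view, which would expose the same windows without copying.

```python
def conv_step(conv_state: list[float], x: list[float], weight: list[float]) -> tuple[list[float], list[float]]:
    """Causal depthwise conv for one channel (hypothetical helper).

    Returns (outputs for each new token, updated conv_state).
    """
    d_conv = len(weight)
    if len(conv_state) != d_conv - 1:
        raise ValueError("conv_state must hold d_conv - 1 previous inputs")
    # Prepend the saved inputs so the first new token sees a full window.
    buf = list(conv_state) + list(x)
    # Window t is buf[t : t + d_conv]; consecutive windows overlap by d_conv - 1
    # values, which is what a self-overlapping view exposes over the same buffer.
    y = [sum(w * v for w, v in zip(weight, buf[t:t + d_conv])) for t in range(len(x))]
    # Keep the last d_conv - 1 inputs as the state for the next call.
    new_state = buf[len(buf) - (d_conv - 1):]
    return y, new_state
```

Because the state carries exactly the inputs the next window needs, one call over several tokens produces the same outputs and final state as one call per token, which is the property that lets the conv step handle batches.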
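Similarly, here is a rough sketch of the per-channel recurrence a fused selective-scan operator like `ggml_ssm_scan` evaluates. It assumes the usual discretized Mamba update `h_t = exp(dt_t * A) * h_{t-1} + B_t * (x_t * dt_t)` and `y_t = C_t · h_t`, with `dt_t` passed through a softplus that returns its input above a threshold of 20 (the default threshold of `torch.nn.Softplus`). The D skip term and the gating that surround the scan in Mamba are left out. Only the final state is returned alongside `y`, in the spirit of the "reduce memory usage" change above. Function names and argument layout are illustrative, not ggml's.

```python
import math


def softplus(x: float, threshold: float = 20.0) -> float:
    # Above the threshold log(1 + exp(x)) equals x in float precision;
    # returning x directly also avoids an OverflowError from math.exp.
    return x if x > threshold else math.log1p(math.exp(x))


def ssm_scan_channel(h, x, dt, A, B, C):
    """Selective scan for one channel (illustrative sketch).

    h: initial state (length d_state); x, dt: one scalar per token;
    A: length d_state; B, C: one length-d_state vector per token.
    Returns (y for each token, final state).
    """
    h = list(h)
    y = []
    for t in range(len(x)):
        d = softplus(dt[t])
        # Decay the previous state, then inject the current (dt-scaled) input.
        h = [math.exp(d * a) * s + b * (x[t] * d) for s, a, b in zip(h, A, B[t])]
        y.append(sum(c * s for c, s in zip(C[t], h)))
    return y, h
```

Since only `h` links one token to the next, a batch can be split anywhere and resumed from the returned state with identical results.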
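The remark that copy chains "can't work" with `ggml_get_rows` can be illustrated with plain lists. A row gather lets every destination read from the *old* states in one pass, so copying cell 0 to 1 and then cell 1 to 2 leaves cell 2 with the old contents of cell 1. One way to support chains, sketched here as an assumption rather than llama.cpp's actual approach, is to apply the copies to an index map first and gather once. Both helpers are hypothetical.

```python
def gather(states: list[str], src: list[int]) -> list[str]:
    # One-shot gather: destination i receives the *old* states[src[i]].
    return [states[s] for s in src]


def resolve_copies(n_cells: int, copies: list[tuple[int, int]]) -> list[int]:
    # Apply sequential (src, dst) copies to an index map instead of to the data,
    # so a single gather afterwards reproduces the sequential result.
    src = list(range(n_cells))
    for s, d in copies:
        src[d] = src[s]
    return src


states = ["a", "b", "c"]
# Naive: both copies read the old states, so the chain is broken.
print(gather(states, [0, 0, 1]))                            # ['a', 'a', 'b']
# Resolved: cell 2 ends up with cell 0's state, as sequential copies would give.
print(gather(states, resolve_copies(3, [(0, 1), (1, 2)])))  # ['a', 'a', 'a']
```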
456 lines
17 KiB
Python
from __future__ import annotations

import os
import shutil
import struct
import tempfile
from enum import Enum, auto
from io import BufferedWriter
from typing import IO, Any, Sequence

import numpy as np

from .constants import (
    GGUF_DEFAULT_ALIGNMENT,
    GGUF_MAGIC,
    GGUF_VERSION,
    GGMLQuantizationType,
    GGUFEndian,
    GGUFValueType,
    Keys,
    RopeScalingType,
    PoolingType,
    TokenType,
)


class WriterState(Enum):
    EMPTY = auto()
    HEADER = auto()
    KV_DATA = auto()
    TI_DATA = auto()


class GGUFWriter:
    fout: BufferedWriter
    temp_file: tempfile.SpooledTemporaryFile[bytes] | None
    tensors: list[np.ndarray[Any, Any]]
    _simple_value_packing = {
        GGUFValueType.UINT8: "B",
        GGUFValueType.INT8: "b",
        GGUFValueType.UINT16: "H",
        GGUFValueType.INT16: "h",
        GGUFValueType.UINT32: "I",
        GGUFValueType.INT32: "i",
        GGUFValueType.FLOAT32: "f",
        GGUFValueType.UINT64: "Q",
        GGUFValueType.INT64: "q",
        GGUFValueType.FLOAT64: "d",
        GGUFValueType.BOOL: "?",
    }

    def __init__(
        self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True,
        endianess: GGUFEndian = GGUFEndian.LITTLE,
    ):
        self.fout = open(path, "wb")
        self.arch = arch
        self.endianess = endianess
        self.offset_tensor = 0
        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
        self.kv_data = bytearray()
        self.kv_data_count = 0
        self.ti_data = bytearray()
        self.ti_data_count = 0
        self.use_temp_file = use_temp_file
        self.temp_file = None
        self.tensors = []
        print("gguf: This GGUF file is for {0} Endian only".format(
            "Big" if self.endianess == GGUFEndian.BIG else "Little",
        ))
        self.state = WriterState.EMPTY

        self.add_architecture()

    def write_header_to_file(self) -> None:
        if self.state is not WriterState.EMPTY:
            raise ValueError(f'Expected output file to be empty, got {self.state}')

        self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
        self._write_packed("I", GGUF_VERSION)
        self._write_packed("Q", self.ti_data_count)
        self._write_packed("Q", self.kv_data_count)
        self.flush()
        self.state = WriterState.HEADER

    def write_kv_data_to_file(self) -> None:
        if self.state is not WriterState.HEADER:
            raise ValueError(f'Expected output file to contain the header, got {self.state}')

        self.fout.write(self.kv_data)
        self.flush()
        self.state = WriterState.KV_DATA

    def write_ti_data_to_file(self) -> None:
        if self.state is not WriterState.KV_DATA:
            raise ValueError(f'Expected output file to contain KV data, got {self.state}')

        self.fout.write(self.ti_data)
        self.flush()
        self.state = WriterState.TI_DATA

    def add_key(self, key: str) -> None:
        self.add_val(key, GGUFValueType.STRING, add_vtype=False)

    def add_uint8(self, key: str, val: int) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.UINT8)

    def add_int8(self, key: str, val: int) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.INT8)

    def add_uint16(self, key: str, val: int) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.UINT16)

    def add_int16(self, key: str, val: int) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.INT16)

    def add_uint32(self, key: str, val: int) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.UINT32)

    def add_int32(self, key: str, val: int) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.INT32)

    def add_float32(self, key: str, val: float) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.FLOAT32)

    def add_uint64(self, key: str, val: int) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.UINT64)

    def add_int64(self, key: str, val: int) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.INT64)

    def add_float64(self, key: str, val: float) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.FLOAT64)

    def add_bool(self, key: str, val: bool) -> None:
        self.add_key(key)
        self.add_val(val, GGUFValueType.BOOL)

    def add_string(self, key: str, val: str) -> None:
        if not val:
            return
        self.add_key(key)
        self.add_val(val, GGUFValueType.STRING)

    def add_array(self, key: str, val: Sequence[Any]) -> None:
        if not isinstance(val, Sequence):
            raise ValueError("Value must be a sequence for array type")

        self.add_key(key)
        self.add_val(val, GGUFValueType.ARRAY)

    def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
        if vtype is None:
            vtype = GGUFValueType.get_type(val)

        if add_vtype:
            self.kv_data += self._pack("I", vtype)
            self.kv_data_count += 1

        pack_fmt = self._simple_value_packing.get(vtype)
        if pack_fmt is not None:
            self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
        elif vtype == GGUFValueType.STRING:
            encoded_val = val.encode("utf8") if isinstance(val, str) else val
            self.kv_data += self._pack("Q", len(encoded_val))
            self.kv_data += encoded_val
        elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
            ltype = GGUFValueType.get_type(val[0])
            if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
                raise ValueError("All items in a GGUF array should be of the same type")
            self.kv_data += self._pack("I", ltype)
            self.kv_data += self._pack("Q", len(val))
            for item in val:
                self.add_val(item, add_vtype=False)
        else:
            raise ValueError("Invalid GGUF metadata value type or value")

    @staticmethod
    def ggml_pad(x: int, n: int) -> int:
        return ((x + n - 1) // n) * n

    def add_tensor_info(
        self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32],
        tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
    ) -> None:
        if self.state is not WriterState.EMPTY:
            raise ValueError(f'Expected output file to be empty, got {self.state}')

        if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
            raise ValueError("Only F32 and F16 tensors are supported for now")

        encoded_name = name.encode("utf8")
        self.ti_data += self._pack("Q", len(encoded_name))
        self.ti_data += encoded_name
        n_dims = len(tensor_shape)
        self.ti_data += self._pack("I", n_dims)
        for i in range(n_dims):
            self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
        if raw_dtype is None:
            dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
        else:
            dtype = raw_dtype
        self.ti_data += self._pack("I", dtype)
        self.ti_data += self._pack("Q", self.offset_tensor)
        self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
        self.ti_data_count += 1

    def add_tensor(
        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
        raw_dtype: GGMLQuantizationType | None = None,
    ) -> None:
        if self.endianess == GGUFEndian.BIG:
            tensor.byteswap(inplace=True)
        if self.use_temp_file and self.temp_file is None:
            fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
            fp.seek(0)
            self.temp_file = fp

        shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
        self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)

        if self.temp_file is None:
            self.tensors.append(tensor)
            return

        tensor.tofile(self.temp_file)
        self.write_padding(self.temp_file, tensor.nbytes)

    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
        pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
        if pad != 0:
            fp.write(bytes([0] * pad))

    def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
        if self.state is not WriterState.TI_DATA:
            raise ValueError(f'Expected output file to contain tensor info, got {self.state}')

        if self.endianess == GGUFEndian.BIG:
            tensor.byteswap(inplace=True)
        self.write_padding(self.fout, self.fout.tell())
        tensor.tofile(self.fout)
        self.write_padding(self.fout, tensor.nbytes)

    def write_tensors_to_file(self) -> None:
        self.write_ti_data_to_file()

        self.write_padding(self.fout, self.fout.tell())

        if self.temp_file is None:
            while True:
                try:
                    tensor = self.tensors.pop(0)
                except IndexError:
                    break
                tensor.tofile(self.fout)
                self.write_padding(self.fout, tensor.nbytes)
            return

        self.temp_file.seek(0)

        shutil.copyfileobj(self.temp_file, self.fout)
        self.flush()
        self.temp_file.close()

    def flush(self) -> None:
        self.fout.flush()

    def close(self) -> None:
        self.fout.close()

    def add_architecture(self) -> None:
        self.add_string(Keys.General.ARCHITECTURE, self.arch)

    def add_author(self, author: str) -> None:
        self.add_string(Keys.General.AUTHOR, author)

    def add_tensor_data_layout(self, layout: str) -> None:
        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)

    def add_url(self, url: str) -> None:
        self.add_string(Keys.General.URL, url)

    def add_description(self, description: str) -> None:
        self.add_string(Keys.General.DESCRIPTION, description)

    def add_source_url(self, url: str) -> None:
        self.add_string(Keys.General.SOURCE_URL, url)

    def add_source_hf_repo(self, repo: str) -> None:
        self.add_string(Keys.General.SOURCE_HF_REPO, repo)

    def add_file_type(self, ftype: int) -> None:
        self.add_uint32(Keys.General.FILE_TYPE, ftype)

    def add_name(self, name: str) -> None:
        self.add_string(Keys.General.NAME, name)

    def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
        self.add_uint32(
            Keys.General.QUANTIZATION_VERSION, quantization_version)

    def add_custom_alignment(self, alignment: int) -> None:
        self.data_alignment = alignment
        self.add_uint32(Keys.General.ALIGNMENT, alignment)

    def add_context_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)

    def add_embedding_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_block_count(self, length: int) -> None:
        self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)

    def add_feed_forward_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_parallel_residual(self, use: bool) -> None:
        self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

    def add_head_count(self, count: int) -> None:
        self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)

    def add_head_count_kv(self, count: int) -> None:
        self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)

    def add_key_length(self, length: int) -> None:
        self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)

    def add_value_length(self, length: int) -> None:
        self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)

    def add_max_alibi_bias(self, bias: float) -> None:
        self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)

    def add_clamp_kqv(self, value: float) -> None:
        self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)

    def add_expert_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)

    def add_expert_used_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)

    def add_layer_norm_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

    def add_layer_norm_rms_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)

    def add_causal_attention(self, value: bool) -> None:
        self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

    def add_pooling_type(self, value: PoolingType) -> None:
        self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

    def add_rope_dimension_count(self, count: int) -> None:
        self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)

    def add_rope_freq_base(self, value: float) -> None:
        self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)

    def add_rope_scaling_type(self, value: RopeScalingType) -> None:
        self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)

    def add_rope_scaling_factor(self, value: float) -> None:
        self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)

    def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
        self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)

    def add_rope_scaling_finetuned(self, value: bool) -> None:
        self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)

    def add_ssm_conv_kernel(self, value: int) -> None:
        self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

    def add_ssm_inner_size(self, value: int) -> None:
        self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)

    def add_ssm_state_size(self, value: int) -> None:
        self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)

    def add_ssm_time_step_rank(self, value: int) -> None:
        self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

    def add_tokenizer_model(self, model: str) -> None:
        self.add_string(Keys.Tokenizer.MODEL, model)

    def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
        self.add_array(Keys.Tokenizer.LIST, tokens)

    def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
        self.add_array(Keys.Tokenizer.MERGES, merges)

    def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
        self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)

    def add_token_type_count(self, value: int) -> None:
        self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)

    def add_token_scores(self, scores: Sequence[float]) -> None:
        self.add_array(Keys.Tokenizer.SCORES, scores)

    def add_bos_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.BOS_ID, id)

    def add_eos_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.EOS_ID, id)

    def add_unk_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.UNK_ID, id)

    def add_sep_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.SEP_ID, id)

    def add_pad_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.PAD_ID, id)

    def add_cls_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.CLS_ID, id)

    def add_mask_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.MASK_ID, id)

    def add_add_bos_token(self, value: bool) -> None:
        self.add_bool(Keys.Tokenizer.ADD_BOS, value)

    def add_add_eos_token(self, value: bool) -> None:
        self.add_bool(Keys.Tokenizer.ADD_EOS, value)

    def add_add_space_prefix(self, value: bool) -> None:
        self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)

    def add_chat_template(self, value: str) -> None:
        self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)

    def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
        pack_prefix = ''
        if not skip_pack_prefix:
            pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
        return struct.pack(f'{pack_prefix}{fmt}', value)

    def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
        self.fout.write(self._pack(fmt, value, skip_pack_prefix))
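`write_header_to_file` emits a fixed 24-byte header: the magic, the format version, then the tensor count and the KV pair count as 64-bit integers. The stdlib-only sketch below rebuilds that layout. The constant values are assumptions about `gguf.constants` (magic `0x46554747`, which packs little-endian to the bytes `GGUF`, and version 3). As in `_write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)`, the magic is always little-endian, while the other fields follow the chosen byte order. Because the header stores both counts, every KV pair and tensor info has to be added before the header is written.

```python
import struct

# Assumed values, mirroring gguf-py's constants at the time of this commit.
GGUF_MAGIC = 0x46554747  # b"GGUF" when packed little-endian
GGUF_VERSION = 3


def pack_header(n_tensors: int, n_kv: int, big_endian: bool = False) -> bytes:
    order = ">" if big_endian else "<"
    header = struct.pack("<I", GGUF_MAGIC)            # magic ignores the file's byte order
    header += struct.pack(order + "I", GGUF_VERSION)  # uint32 format version
    header += struct.pack(order + "Q", n_tensors)     # uint64 tensor count (ti_data_count)
    header += struct.pack(order + "Q", n_kv)          # uint64 KV pair count (kv_data_count)
    return header


def unpack_header(data: bytes, big_endian: bool = False) -> tuple[int, int, int]:
    if data[:4] != b"GGUF":
        raise ValueError("not a GGUF file")
    order = ">" if big_endian else "<"
    # Returns (version, tensor count, KV pair count).
    return struct.unpack_from(order + "IQQ", data, 4)
```

A reader can therefore recognize a GGUF file from its first four bytes before it knows which byte order the rest of the file uses.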
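In `kv_data`, each metadata pair is the key (written by `add_key` as a string with no type tag), a `uint32` value type, and the payload. Strings are a `uint64` byte length followed by UTF-8 bytes with no terminator. Arrays carry an element type and a `uint64` count before their items, which are written without per-item type tags (`add_vtype=False`). Since `kv_data_count` is only incremented when `add_vtype` is true, it counts pairs, not keys or array items. The sketch below reproduces that byte layout for little-endian output. The numeric type codes are assumed to match `GGUFValueType` (UINT32 = 4, INT32 = 5, STRING = 8, ARRAY = 9), a list of Python ints is assumed to be tagged INT32 by `GGUFValueType.get_type`, and the helper names are hypothetical.

```python
import struct

# Assumed GGUFValueType codes.
UINT32, INT32, STRING, ARRAY = 4, 5, 8, 9


def pack_str(s: str) -> bytes:
    data = s.encode("utf8")
    return struct.pack("<Q", len(data)) + data  # uint64 length prefix, no NUL terminator


def kv_uint32(key: str, val: int) -> bytes:
    return pack_str(key) + struct.pack("<II", UINT32, val)


def kv_string(key: str, val: str) -> bytes:
    return pack_str(key) + struct.pack("<I", STRING) + pack_str(val)


def kv_int32_array(key: str, vals: list[int]) -> bytes:
    if not vals:
        raise ValueError("empty arrays are rejected, as in add_val")
    out = pack_str(key) + struct.pack("<I", ARRAY)
    out += struct.pack("<IQ", INT32, len(vals))          # element type, then element count
    out += b"".join(struct.pack("<i", v) for v in vals)  # items carry no per-item type tag
    return out
```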
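For each tensor, `add_tensor_info` writes the name, `n_dims`, the dimensions in reverse order, the type, and an offset into the data section. Numpy shapes list the outermost dimension first, while ggml lists the innermost one first. Each tensor's size is rounded up with `ggml_pad`, so every offset lands on a multiple of the alignment. The sketch below shows only that bookkeeping. It assumes the default alignment is 32 bytes (the value of `GGUF_DEFAULT_ALIGNMENT` is an assumption here), and `plan_tensor_infos` is a hypothetical helper.

```python
def ggml_pad(x: int, n: int) -> int:
    # Round x up to the next multiple of n (same arithmetic as GGUFWriter.ggml_pad).
    return ((x + n - 1) // n) * n


def plan_tensor_infos(tensors: list[tuple[str, tuple[int, ...], int]], alignment: int = 32):
    """tensors: (name, numpy-style shape, nbytes). Returns (infos, total data size)."""
    infos = []
    offset = 0
    for name, shape, nbytes in tensors:
        ne = list(reversed(shape))  # ggml order: innermost dimension first
        infos.append((name, ne, offset))
        offset += ggml_pad(nbytes, alignment)  # the next tensor starts on an aligned boundary
    return infos, offset
```

For example, a 25-element float32 tensor (100 bytes) reserves 128 bytes, so the tensor after it starts at offset 128.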
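The `WriterState` checks impose a fixed call order. Tensor infos (which `add_tensor_info` only accepts in the `EMPTY` state) and KV pairs are accumulated in memory first. After that, `write_header_to_file`, `write_kv_data_to_file`, and `write_tensors_to_file` must run in that order. The last of these calls `write_ti_data_to_file` itself before the padded tensor data. The toy class below is a hypothetical stand-in, not `GGUFWriter`, that reproduces only this ordering so the failure mode is easy to see.

```python
from enum import Enum, auto


class State(Enum):
    EMPTY = auto()
    HEADER = auto()
    KV_DATA = auto()
    TI_DATA = auto()


class OrderedWriter:
    """Toy writer that enforces the same section order as GGUFWriter."""

    def __init__(self) -> None:
        self.state = State.EMPTY
        self.sections: list[str] = []

    def _advance(self, expected: State, new: State, section: str) -> None:
        if self.state is not expected:
            raise ValueError(f"Expected {expected}, got {self.state}")
        self.sections.append(section)
        self.state = new

    def write_header(self) -> None:
        self._advance(State.EMPTY, State.HEADER, "header")

    def write_kv_data(self) -> None:
        self._advance(State.HEADER, State.KV_DATA, "kv")

    def write_tensors(self) -> None:
        # Like write_tensors_to_file: tensor infos first, then the (padded) tensor data.
        self._advance(State.KV_DATA, State.TI_DATA, "tensor_info")
        self.sections.append("tensor_data")
```

Calling the steps out of order raises `ValueError` before anything is written, instead of producing a file whose header counts disagree with its contents.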
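With `use_temp_file=True`, `add_tensor` streams each tensor and its zero padding into a `tempfile.SpooledTemporaryFile`. That file stays in memory up to `max_size=256 * 1024 * 1024` bytes and only then rolls over to disk. `write_tensors_to_file` later rewinds it and appends it to the output with `shutil.copyfileobj`. The sketch below follows the same pattern using the standard library only, with raw `bytes` standing in for numpy arrays and an `io.BytesIO` standing in for the output file. The 32-byte alignment and the helper names are assumptions.

```python
import io
import shutil
import tempfile

ALIGNMENT = 32  # assumed default alignment


def pad_to(fp, n: int, align: int = ALIGNMENT) -> None:
    # Same rule as GGUFWriter.write_padding: write zeros up to the next multiple of align.
    pad = ((n + align - 1) // align) * align - n
    if pad:
        fp.write(bytes(pad))


def spool_and_copy(blobs: list[bytes], out) -> None:
    with tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) as tmp:
        for blob in blobs:
            tmp.write(blob)
            pad_to(tmp, len(blob))  # each blob is padded by its own size, as in add_tensor
        tmp.seek(0)                 # rewind before copying, as write_tensors_to_file does
        shutil.copyfileobj(tmp, out)


out = io.BytesIO()
spool_and_copy([b"\x01" * 40, b"\x02" * 32], out)
print(len(out.getvalue()))  # 96: 40 bytes + 24 padding, then 32 bytes needing no padding
```

Padding each blob by its own length keeps positions aligned only because the temporary file starts at offset 0. The output file itself is padded separately (`write_padding(self.fout, self.fout.tell())`) before the copy, so the data section begins on an aligned boundary.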