mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-05 16:10:42 +01:00
add back convert hf to gguf
This commit is contained in:
parent
0a81051ae2
commit
6cabdda0df
@ -17,6 +17,7 @@ from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
|
||||
from itertools import chain
|
||||
|
||||
from transformers import AutoConfig
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -66,6 +67,12 @@ class Model:
|
||||
metadata_override: Path | None
|
||||
dir_model_card: Path
|
||||
|
||||
# for vision model
|
||||
preprocessor_config: dict[str, Any] | None = None
|
||||
vparams: dict[str, Any] | None = None
|
||||
v_tensor_map: gguf.TensorNameMap
|
||||
v_tensor_names: set[str] | None
|
||||
|
||||
# subclasses should define this!
|
||||
model_arch: gguf.MODEL_ARCH
|
||||
|
||||
@ -95,6 +102,7 @@ class Model:
|
||||
self.metadata_override = metadata_override
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
@ -210,9 +218,13 @@ class Model:
|
||||
|
||||
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
|
||||
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
|
||||
if new_name is None:
|
||||
new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes)
|
||||
if new_name is not None:
|
||||
return new_name
|
||||
elif new_name_vision is not None:
|
||||
return new_name_vision
|
||||
else:
|
||||
raise ValueError(f"Can not map tensor {name!r}")
|
||||
return new_name
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
@ -466,7 +478,24 @@ class Model:
|
||||
@staticmethod
|
||||
def load_hparams(dir_model: Path):
|
||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
hparams = json.load(f)
|
||||
if "text_config" in hparams:
|
||||
text_config = hparams["text_config"]
|
||||
# for example, llava-1.5-7b-hf misses the language model config, need to retrieve it via model ID
|
||||
if "_name_or_path" in text_config:
|
||||
text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
|
||||
hparams = {**text_config, **hparams}
|
||||
return hparams
|
||||
|
||||
@staticmethod
|
||||
def load_preprocessor_config(dir_model: Path):
|
||||
# TODO: this varies vastly among models, need to handle more cases in the future
|
||||
file_path = dir_model / "preprocessor_config.json"
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
|
||||
@ -1557,10 +1586,17 @@ class StableLMModel(Model):
|
||||
raise ValueError(f"Unprocessed norms: {norms}")
|
||||
|
||||
|
||||
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
|
||||
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration")
|
||||
class LlamaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if "vision_config" in self.hparams:
|
||||
self.vparams = self.hparams["vision_config"]
|
||||
if self.vparams is not None:
|
||||
self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"])
|
||||
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
@ -1594,6 +1630,26 @@ class LlamaModel(Model):
|
||||
if self.hparams.get("vocab_size", 32000) == 49152:
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
|
||||
# For vision model
|
||||
if self.vparams is not None and self.preprocessor_config is not None:
|
||||
self.gguf_writer.add_vision_type("clip-vit")
|
||||
self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
|
||||
self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
|
||||
self.gguf_writer.add_vision_clip_architecture("llava")
|
||||
self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
|
||||
self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
|
||||
self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
|
||||
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
|
||||
self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
|
||||
self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
|
||||
self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
|
||||
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
|
||||
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
|
||||
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
|
||||
# TODO: should not hardcode these, but they are currently missing from config.json
|
||||
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
|
||||
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
@ -1624,6 +1680,12 @@ class LlamaModel(Model):
|
||||
n_head = self.hparams["num_attention_heads"]
|
||||
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||
|
||||
# For vision model
|
||||
if name.startswith("language_model"):
|
||||
name = name.replace("language_model.", "")
|
||||
if "post_layernorm" in name:
|
||||
return [] # skip post_layernorm
|
||||
|
||||
if name.endswith(("q_proj.weight", "q_proj.bias")):
|
||||
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
|
||||
if name.endswith(("k_proj.weight", "k_proj.bias")):
|
||||
|
@ -2949,6 +2949,7 @@ struct server_context {
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
nullptr,
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
|
@ -202,6 +202,9 @@ class Keys:
|
||||
FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id"
|
||||
FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id"
|
||||
FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id"
|
||||
# Vision models
|
||||
IMAGE_START_ID = "tokenizer.ggml.image_start_token_id"
|
||||
IMAGE_END_ID = "tokenizer.ggml.image_end_token_id"
|
||||
# deprecated:
|
||||
PREFIX_ID = "tokenizer.ggml.prefix_token_id"
|
||||
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
|
||||
@ -211,6 +214,31 @@ class Keys:
|
||||
TYPE = "adapter.type"
|
||||
LORA_ALPHA = "adapter.lora.alpha"
|
||||
|
||||
class Vision:
|
||||
# only support vision.type = "clip-vit" for now
|
||||
TYPE = "vision.type"
|
||||
IMAGE_SIZE = "vision.image_size"
|
||||
PATCH_SIZE = "vision.patch_size"
|
||||
IMAGE_MEAN = "vision.image_mean"
|
||||
IMAGE_STD = "vision.image_std"
|
||||
|
||||
class Clip:
|
||||
ARCHITECTURE = "vision.clip.architecture"
|
||||
CONTEXT_LENGTH = "vision.clip.context_length"
|
||||
EMBEDDING_LENGTH = "vision.clip.embedding_length"
|
||||
BLOCK_COUNT = "vision.clip.block_count"
|
||||
FEED_FORWARD_LENGTH = "vision.clip.feed_forward_length"
|
||||
PROJECTION_TYPE = "vision.clip.projection_type"
|
||||
PROJECTION_DIM = "vision.clip.projection_dim"
|
||||
USE_GELU = "vision.clip.use_gelu"
|
||||
MAX_POS_EMBEDDING = "vision.clip.max_position_embeddings"
|
||||
MAX_SLICES = "vision.clip.max_slices"
|
||||
PROJECTOR_TYPE = "vision.clip.projector_type"
|
||||
SELECT_LAYER = "vision.clip.select_layer"
|
||||
PATCH_MERGE_TYPE = "vision.clip.patch_merge_type"
|
||||
HEAD_COUNT = "vision.clip.attention.head_count"
|
||||
LAYERNORM_EPS = "vision.clip.attention.layer_norm_epsilon"
|
||||
|
||||
#
|
||||
# recommended mapping of model tensor names for storage in gguf
|
||||
#
|
||||
@ -279,6 +307,8 @@ class MODEL_ARCH(IntEnum):
|
||||
GRANITE_MOE = auto()
|
||||
CHAMELEON = auto()
|
||||
WAVTOKENIZER_DEC = auto()
|
||||
# vision models
|
||||
LLAVA_VISION = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
@ -390,6 +420,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
ENC_OUTPUT_NORM = auto()
|
||||
CLS = auto() # classifier
|
||||
CLS_OUT = auto() # classifier output projection
|
||||
# wavtokenizer
|
||||
CONV1D = auto()
|
||||
CONVNEXT_DW = auto()
|
||||
CONVNEXT_NORM = auto()
|
||||
@ -406,6 +437,21 @@ class MODEL_TENSOR(IntEnum):
|
||||
POSNET_ATTN_K = auto()
|
||||
POSNET_ATTN_V = auto()
|
||||
POSNET_ATTN_OUT = auto()
|
||||
# vision
|
||||
V_MMPROJ = auto()
|
||||
V_ENC_EMBD_CLS = auto()
|
||||
V_ENC_EMBD_PATCH = auto()
|
||||
V_ENC_EMBD_POS = auto()
|
||||
V_ENC_ATTN_Q = auto()
|
||||
V_ENC_ATTN_K = auto()
|
||||
V_ENC_ATTN_V = auto()
|
||||
V_ENC_INPUT_NORM = auto()
|
||||
V_ENC_OUTPUT = auto()
|
||||
V_ENC_OUTPUT_NORM = auto()
|
||||
V_ENC_FFN_UP = auto()
|
||||
V_ENC_FFN_DOWN = auto()
|
||||
V_PRE_NORM = auto()
|
||||
V_POST_NORM = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
@ -593,6 +639,21 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
|
||||
MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
|
||||
# vision
|
||||
MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}",
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls",
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch",
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos",
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: "v.enc.blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.V_ENC_ATTN_K: "v.enc.blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.V_ENC_ATTN_V: "v.enc.blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM: "v.enc.blk.{bid}.input_norm",
|
||||
MODEL_TENSOR.V_ENC_OUTPUT: "v.enc.blk.{bid}.output",
|
||||
MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.enc.blk.{bid}.output_norm",
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: "v.enc.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.enc.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.V_PRE_NORM: "v.pre_norm",
|
||||
MODEL_TENSOR.V_POST_NORM: "v.post_norm",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
@ -1534,6 +1595,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.POSNET_ATTN_V,
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT,
|
||||
],
|
||||
MODEL_ARCH.LLAVA_VISION: [
|
||||
MODEL_TENSOR.V_MMPROJ,
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS,
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH,
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS,
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q,
|
||||
MODEL_TENSOR.V_ENC_ATTN_K,
|
||||
MODEL_TENSOR.V_ENC_ATTN_V,
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM,
|
||||
MODEL_TENSOR.V_ENC_OUTPUT,
|
||||
MODEL_TENSOR.V_ENC_OUTPUT_NORM,
|
||||
MODEL_TENSOR.V_ENC_FFN_UP,
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN,
|
||||
MODEL_TENSOR.V_PRE_NORM,
|
||||
MODEL_TENSOR.V_POST_NORM,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
@ -1615,6 +1692,15 @@ class PoolingType(IntEnum):
|
||||
CLS = 2
|
||||
|
||||
|
||||
class CLIPProjectorType(Enum):
|
||||
MLP = 'mlp'
|
||||
|
||||
|
||||
class CLIPPatchMergeType(Enum):
|
||||
FLAT = 'flat'
|
||||
SPATIAL_UNPAD = 'spatial_unpad'
|
||||
|
||||
|
||||
class GGMLQuantizationType(IntEnum):
|
||||
F32 = 0
|
||||
F16 = 1
|
||||
|
@ -27,6 +27,8 @@ from .constants import (
|
||||
PoolingType,
|
||||
TokenType,
|
||||
ExpertGatingFuncType,
|
||||
CLIPPatchMergeType,
|
||||
CLIPProjectorType,
|
||||
)
|
||||
|
||||
from .quants import quant_shape_from_byte_shape
|
||||
@ -874,6 +876,57 @@ class GGUFWriter:
|
||||
|
||||
def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
|
||||
self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
|
||||
|
||||
def add_vision_type(self, value: str) -> None:
|
||||
self.add_string(Keys.Vision.TYPE, value)
|
||||
|
||||
def add_vision_image_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Vision.IMAGE_SIZE, value)
|
||||
|
||||
def add_vision_patch_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Vision.PATCH_SIZE, value)
|
||||
|
||||
def add_vision_clip_architecture(self, value: str) -> None:
|
||||
self.add_string(Keys.Vision.Clip.ARCHITECTURE, value)
|
||||
|
||||
def add_vision_clip_context_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Vision.Clip.CONTEXT_LENGTH, value)
|
||||
|
||||
def add_vision_clip_embedding_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Vision.Clip.EMBEDDING_LENGTH, value)
|
||||
|
||||
def add_vision_clip_block_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Vision.Clip.BLOCK_COUNT, value)
|
||||
|
||||
def add_vision_clip_feed_forward_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Vision.Clip.FEED_FORWARD_LENGTH, value)
|
||||
|
||||
def add_vision_clip_head_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Vision.Clip.HEAD_COUNT, value)
|
||||
|
||||
def add_vision_clip_max_position_embeddings(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Vision.Clip.MAX_POS_EMBEDDING, value)
|
||||
|
||||
def add_vision_clip_projector_type(self, value: CLIPProjectorType) -> None:
|
||||
self.add_string(Keys.Vision.Clip.PROJECTOR_TYPE, value.value)
|
||||
|
||||
def add_vision_clip_max_slices(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Vision.Clip.MAX_SLICES, value)
|
||||
|
||||
def add_vision_clip_select_layer(self, value: int) -> None:
|
||||
self.add_int32(Keys.Vision.Clip.SELECT_LAYER, value)
|
||||
|
||||
def add_vision_clip_patch_merge_type(self, value: CLIPPatchMergeType) -> None:
|
||||
self.add_string(Keys.Vision.Clip.PATCH_MERGE_TYPE, value.value)
|
||||
|
||||
def add_vision_clip_layer_norm_epsilon(self, value: float) -> None:
|
||||
self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value)
|
||||
|
||||
def add_vision_clip_image_mean(self, value: Sequence[float]) -> None:
|
||||
self.add_array(Keys.Vision.IMAGE_MEAN, value)
|
||||
|
||||
def add_vision_clip_image_std(self, value: Sequence[float]) -> None:
|
||||
self.add_array(Keys.Vision.IMAGE_STD, value)
|
||||
|
||||
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
|
||||
if not isinstance(value, str):
|
||||
|
@ -787,6 +787,64 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT: (
|
||||
"backbone.posnet.{bid}.proj_out", # wavtokenizer
|
||||
),
|
||||
|
||||
#############################################################################
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ: (
|
||||
"multi_modal_projector.linear_{bid}",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS: (
|
||||
"vision_tower.vision_model.embeddings.class_embedding",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
|
||||
"vision_tower.vision_model.embeddings.patch_embedding",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: (
|
||||
"vision_tower.vision_model.embeddings.position_embedding",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_K: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_V: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_OUTPUT: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_PRE_NORM: (
|
||||
"vision_tower.vision_model.pre_layrnorm",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_POST_NORM: (
|
||||
"vision_tower.vision_model.post_layernorm",
|
||||
),
|
||||
}
|
||||
|
||||
# architecture-specific block mappings
|
||||
|
@ -1292,7 +1292,7 @@ extern "C" {
|
||||
|
||||
// Encode patches into embeddings
|
||||
LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_patches * p);
|
||||
LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx);
|
||||
LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_context * ctx);
|
||||
|
||||
//
|
||||
// Model split
|
||||
|
@ -40,7 +40,7 @@ struct clip_hparams {
|
||||
std::array<float, 3> image_mean;
|
||||
std::array<float, 3> image_std;
|
||||
|
||||
std::array<int32_t, 32> image_grid_pinpoints;
|
||||
std::array<int32_t, 32> image_grid_pinpoints; // TODO: should this be array of (x, y) pairs?
|
||||
int32_t image_crop_resolution;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user