mirror of https://github.com/ggerganov/llama.cpp.git
plamo convert

commit b2330f57e2 (parent 4c585b4c6c)
@@ -182,6 +182,8 @@ class Model:
             return QwenModel
         if model_architecture == "MixtralForCausalLM":
             return MixtralModel
+        if model_architecture == "PlamoForCausalLM":
+            return PlamoModel
         return Model

     def _is_model_safetensors(self) -> bool:
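For orientation: the converter picks a Model subclass from the `architectures` field of the checkpoint's config.json, which is where the "PlamoForCausalLM" string above comes from. A minimal sketch of that flow (the helper below is illustrative, not the converter's exact code):

    import json
    from pathlib import Path

    def read_architecture(model_dir: Path) -> str:
        # HF checkpoints declare their model class in config.json, e.g.
        # {"architectures": ["PlamoForCausalLM"], ...}
        with open(model_dir / "config.json", encoding="utf-8") as f:
            return json.load(f)["architectures"][0]

    # model_class = Model.from_model_architecture(read_architecture(dir_model))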
@@ -221,6 +223,8 @@ class Model:
             return gguf.MODEL_ARCH.QWEN
         if arch == "MixtralForCausalLM":
             return gguf.MODEL_ARCH.LLAMA
+        if arch == "PlamoForCausalLM":
+            return gguf.MODEL_ARCH.PLAMO

         raise NotImplementedError(f'Architecture "{arch}" not supported!')

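Both this enum mapping and the class dispatch above key off the same architecture string, so adding a model means updating them in tandem. A hypothetical table-driven variant (not what the script does; shown only to make the coupling explicit):

    # Hypothetical: one table instead of two if-chains, keyed by the
    # config.json architecture string (class/enum names as in this diff).
    ARCH_TABLE: dict[str, tuple[str, str]] = {
        "MixtralForCausalLM": ("MixtralModel", "MODEL_ARCH.LLAMA"),
        "PlamoForCausalLM":   ("PlamoModel",   "MODEL_ARCH.PLAMO"),
    }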
@@ -980,11 +984,72 @@ class QwenModel(Model):
         print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
         self.gguf_writer.add_tensor(new_name, data)


+class PlamoModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name("PLaMo")
+        self.gguf_writer.add_context_length(4096)  # not in config.json
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(5)  # hparams["num_key_value_heads"] is wrong
+        self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
+
+    def write_tensors(self):
+        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            if "self_attn.rotary_emb.inv_freq" in name:
+                continue
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)


 ###### CONVERSION LOGIC ######


 def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
+    parser = argparse.ArgumentParser(
+        description="Convert a huggingface model to a GGML compatible file")
     parser.add_argument(
         "--vocab-only", action="store_true",
         help="extract only the vocab",
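The dtype ladder in write_tensors is the converter's standard policy: ftype 0 (f32 output) promotes everything to float32, while ftype 1 (f16 output) downcasts only 2-dimensional .weight tensors and keeps 1-dimensional tensors (norms, biases) in float32. A self-contained restatement of that policy, assuming the same ftype encoding:

    import numpy as np

    def apply_ftype_policy(data: np.ndarray, name: str, ftype: int) -> np.ndarray:
        n_dims = data.ndim
        if ftype == 0 and data.dtype == np.float16:
            return data.astype(np.float32)   # f32 output: promote f16
        if ftype == 1 and data.dtype == np.float16 and n_dims == 1:
            return data.astype(np.float32)   # 1-dim tensors stay f32 even for f16 output
        if ftype == 1 and data.dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            return data.astype(np.float16)   # f16 output: downcast 2-dim weights
        return data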
@@ -95,6 +95,7 @@ class MODEL_ARCH(IntEnum):
     BLOOM    = auto()
     STABLELM = auto()
     QWEN     = auto()
+    PLAMO    = auto()


 class MODEL_TENSOR(IntEnum):
@@ -140,6 +141,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BLOOM:    "bloom",
     MODEL_ARCH.STABLELM: "stablelm",
     MODEL_ARCH.QWEN:     "qwen",
+    MODEL_ARCH.PLAMO:    "plamo",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -347,6 +349,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PLAMO: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.GPT2: [
         # TODO
     ],
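Registering MODEL_ARCH.PLAMO here is what activates the tensor-name mappings below for this architecture. A quick usage check with gguf-py (the block count of 40 is an assumption for PLaMo-13B; use the checkpoint's num_hidden_layers):

    import gguf

    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.PLAMO, 40)
    name = tmap.get_name("model.layers.layers.0.self_attn.q_proj.weight",
                         try_suffixes=(".weight", ".bias"))
    print(name)  # expected: blk.0.attn_q.weight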
@@ -75,6 +75,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.output.LayerNorm",       # bert
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",                               # yi
+            "model.layers.layers.{bid}.norm",                       # plamo
         ),

         # Attention norm 2
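Each "{bid}" template is expanded per block when the map is built, so one entry here covers every layer. A simplified sketch of that expansion (mirroring, not reproducing, the TensorNameMap constructor):

    templates = ("model.layers.{bid}.ln1", "model.layers.layers.{bid}.norm")
    mapping = {t.format(bid=b): f"blk.{b}.attn_norm"
               for t in templates for b in range(2)}
    # {'model.layers.0.ln1': 'blk.0.attn_norm',
    #  'model.layers.layers.0.norm': 'blk.0.attn_norm', ...}

Note the doubled "model.layers.layers.{bid}" prefix carried by every new entry: PLaMo's HF checkpoints nest their decoder blocks one module deeper than llama-hf.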
@@ -94,26 +95,29 @@ class TensorNameMap:

         # Attention query
         MODEL_TENSOR.ATTN_Q: (
             "model.layers.{bid}.self_attn.q_proj",        # llama-hf
             "layers.{bid}.attention.wq",                  # llama-pth
             "encoder.layer.{bid}.attention.self.query",   # bert
             "transformer.h.{bid}.attn.q_proj",            # gpt-j
+            "model.layers.layers.{bid}.self_attn.q_proj", # plamo
         ),

         # Attention key
         MODEL_TENSOR.ATTN_K: (
             "model.layers.{bid}.self_attn.k_proj",        # llama-hf
             "layers.{bid}.attention.wk",                  # llama-pth
             "encoder.layer.{bid}.attention.self.key",     # bert
             "transformer.h.{bid}.attn.k_proj",            # gpt-j
+            "model.layers.layers.{bid}.self_attn.k_proj", # plamo
         ),

         # Attention value
         MODEL_TENSOR.ATTN_V: (
             "model.layers.{bid}.self_attn.v_proj",        # llama-hf
             "layers.{bid}.attention.wv",                  # llama-pth
             "encoder.layer.{bid}.attention.self.value",   # bert
             "transformer.h.{bid}.attn.v_proj",            # gpt-j
+            "model.layers.layers.{bid}.self_attn.v_proj", # plamo
         ),

         # Attention output
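PLaMo keeps separate q/k/v projections like llama-hf, and set_gguf_parameters above hard-codes head_count_kv(5), i.e. grouped-query attention. A small sanity-check helper for the resulting k/v projection shapes (the hidden-size and head-count figures in the comment are illustrative, not taken from this diff):

    def kv_proj_shape(hidden_size: int, n_head: int, n_head_kv: int) -> tuple[int, int]:
        # Under grouped-query attention, k/v projections shrink by n_head / n_head_kv.
        head_dim = hidden_size // n_head
        return (n_head_kv * head_dim, hidden_size)

    # e.g. hidden_size=5120 with 40 query heads and 5 kv heads
    # gives k/v projections of shape (640, 5120)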
@@ -128,12 +132,14 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.output.dense",               # bert
             "transformer.h.{bid}.attn.out_proj",                        # gpt-j
             "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
+            "model.layers.layers.{bid}.self_attn.o_proj",               # plamo
         ),

         # Rotary embeddings
         MODEL_TENSOR.ATTN_ROT_EMBD: (
             "model.layers.{bid}.self_attn.rotary_emb.inv_freq",        # llama-hf
             "layers.{bid}.attention.inner_attention.rope.freqs",       # llama-pth
+            "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
         ),

         # Feed-forward norm
@@ -167,6 +173,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.fc_in",                         # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
             "transformer.h.{bid}.mlp.w1",                            # qwen
+            "model.layers.layers.{bid}.mlp.up_proj",                 # plamo
         ),

         MODEL_TENSOR.FFN_UP_EXP: (
@@ -179,6 +186,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.gate_proj",        # llama-hf refact
             "layers.{bid}.feed_forward.w1",            # llama-pth
             "transformer.h.{bid}.mlp.w2",              # qwen
+            "model.layers.layers.{bid}.mlp.gate_proj", # plamo
         ),

         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -198,6 +206,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.output.dense",                      # bert
             "transformer.h.{bid}.mlp.fc_out",                        # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
+            "model.layers.layers.{bid}.mlp.down_proj",               # plamo
         ),

         MODEL_TENSOR.FFN_DOWN_EXP: (
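With the converter class, the architecture constants, and the tensor-name mappings all in place, a PLaMo checkpoint converts like any other supported HF model. A usage sketch (the local path is illustrative, and --outtype is assumed to map to the ftype 0/1 convention used in write_tensors):

    python convert-hf-to-gguf.py ./plamo-13b --outtype f16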