mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-23 21:17:54 +01:00
llm : support Adept Persimmon 8B (#3410)
* Produces garbage output * wip: correct tensors up to RoPE * correct tensors thru RoPE * Correct outputs through masked & softmax'd KQ * fp32 works * Rename adept->persimmon * Produces correct outputs * clean up convert scripts * remove printing logic from ggml.c * remove prints from llama.cpp & fix merge * trivial cleanups * Add offload funcs * update conversion script to directly take adept artifacts rather than .saftensors file * Fix norm eps bug * Support sqr and concat on metal, persimmon-8b-q4 runs correctly * Small changes from review * Formatting changes * Minor changes to conversion script * Remove old script * Fix editorconfig formatting * Fix build * add overlooked offload code ggml-ci
This commit is contained in:
parent
3a716b4dae
commit
0e797c2fc5
130
convert-persimmon-to-gguf.py
Normal file
130
convert-persimmon-to-gguf.py
Normal file
@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import os
|
||||
from pprint import pprint
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
if 'NO_LOCAL_GGUF' not in os.environ:
|
||||
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
|
||||
import gguf
|
||||
|
||||
def _flatten_dict(dct, tensors, prefix=None):
|
||||
assert isinstance(dct, dict)
|
||||
for key in dct.keys():
|
||||
new_prefix = prefix + '.' + key if prefix is not None else key
|
||||
if isinstance(dct[key], torch.Tensor):
|
||||
tensors[new_prefix] = dct[key]
|
||||
elif isinstance(dct[key], dict):
|
||||
_flatten_dict(dct[key], tensors, new_prefix)
|
||||
else:
|
||||
raise ValueError(type(dct[key]))
|
||||
return None
|
||||
|
||||
def _get_sentencepiece_tokenizer_info(dir_model: Path):
|
||||
tokenizer_path = dir_model / 'adept_vocab.model'
|
||||
print('gguf: getting sentencepiece tokenizer from', tokenizer_path)
|
||||
tokenizer = SentencePieceProcessor(str(tokenizer_path))
|
||||
print('gguf: adding tokens')
|
||||
tokens: list[bytes] = []
|
||||
scores: list[float] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
for i in range(tokenizer.vocab_size()):
|
||||
text: bytes
|
||||
score: float
|
||||
|
||||
piece = tokenizer.id_to_piece(i)
|
||||
text = piece.encode("utf-8")
|
||||
score = tokenizer.get_score(i)
|
||||
|
||||
toktype = 1
|
||||
if tokenizer.is_unknown(i):
|
||||
toktype = 2
|
||||
if tokenizer.is_control(i):
|
||||
toktype = 3
|
||||
if tokenizer.is_unused(i):
|
||||
toktype = 5
|
||||
if tokenizer.is_byte(i):
|
||||
toktype = 6
|
||||
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype)
|
||||
pass
|
||||
return tokens, scores, toktypes
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert a Persimmon model from Adept (e.g. Persimmon 8b chat) to a GGML compatible file")
|
||||
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
|
||||
parser.add_argument("--ckpt-path", type=Path, help="path to persimmon checkpoint .pt file")
|
||||
parser.add_argument("--model-dir", type=Path, help="directory containing model e.g. 8b_chat_model_release")
|
||||
parser.add_argument("--adept-inference-dir", type=str, help="path to adept-inference code directory")
|
||||
args = parser.parse_args()
|
||||
sys.path.append(str(args.adept_inference_dir))
|
||||
persimmon_model = torch.load(args.ckpt_path)
|
||||
hparams = persimmon_model['args']
|
||||
pprint(hparams)
|
||||
tensors = {}
|
||||
_flatten_dict(persimmon_model['model'], tensors, None)
|
||||
|
||||
arch = gguf.MODEL_ARCH.PERSIMMON
|
||||
gguf_writer = gguf.GGUFWriter(args.outfile, gguf.MODEL_ARCH_NAMES[arch])
|
||||
|
||||
block_count = hparams.num_layers
|
||||
head_count = hparams.num_attention_heads
|
||||
head_count_kv = head_count
|
||||
ctx_length = hparams.seq_length
|
||||
hidden_size = hparams.hidden_size
|
||||
|
||||
gguf_writer.add_name('persimmon-8b-chat')
|
||||
gguf_writer.add_context_length(ctx_length)
|
||||
gguf_writer.add_embedding_length(hidden_size)
|
||||
gguf_writer.add_block_count(block_count)
|
||||
gguf_writer.add_feed_forward_length(hparams.ffn_hidden_size)
|
||||
gguf_writer.add_rope_dimension_count(hidden_size // head_count)
|
||||
gguf_writer.add_head_count(head_count)
|
||||
gguf_writer.add_head_count_kv(head_count_kv)
|
||||
gguf_writer.add_rope_freq_base(hparams.rotary_emb_base)
|
||||
gguf_writer.add_layer_norm_eps(hparams.layernorm_epsilon)
|
||||
|
||||
tokens, scores, toktypes = _get_sentencepiece_tokenizer_info(args.model_dir)
|
||||
gguf_writer.add_tokenizer_model('llama')
|
||||
gguf_writer.add_token_list(tokens)
|
||||
gguf_writer.add_token_scores(scores)
|
||||
gguf_writer.add_token_types(toktypes)
|
||||
gguf_writer.add_bos_token_id(71013)
|
||||
gguf_writer.add_eos_token_id(71013)
|
||||
|
||||
tensor_map = gguf.get_tensor_name_map(arch, block_count)
|
||||
print(tensor_map)
|
||||
for name in tensors.keys():
|
||||
data = tensors[name]
|
||||
if name.endswith(".self_attention.rotary_emb.inv_freq"):
|
||||
continue
|
||||
old_dtype = data.dtype
|
||||
# TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
|
||||
data = data.to(torch.float32).squeeze().numpy()
|
||||
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
|
||||
if new_name is None:
|
||||
print("Can not map tensor '" + name + "'")
|
||||
sys.exit()
|
||||
n_dims = len(data.shape)
|
||||
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
|
||||
gguf_writer.add_tensor(new_name, data)
|
||||
print("gguf: write header")
|
||||
gguf_writer.write_header_to_file()
|
||||
print("gguf: write metadata")
|
||||
gguf_writer.write_kv_data_to_file()
|
||||
print("gguf: write tensors")
|
||||
gguf_writer.write_tensors_to_file()
|
||||
|
||||
gguf_writer.close()
|
||||
|
||||
print(f"gguf: model successfully exported to '{args.outfile}'")
|
||||
print("")
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
54
ggml-metal.m
54
ggml-metal.m
@ -109,6 +109,8 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_DECL_KERNEL(concat);
|
||||
GGML_METAL_DECL_KERNEL(sqr);
|
||||
|
||||
#undef GGML_METAL_DECL_KERNEL
|
||||
};
|
||||
@ -300,6 +302,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_ADD_KERNEL(concat);
|
||||
GGML_METAL_ADD_KERNEL(sqr);
|
||||
|
||||
#undef GGML_METAL_ADD_KERNEL
|
||||
}
|
||||
@ -375,6 +379,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_DEL_KERNEL(concat);
|
||||
GGML_METAL_DEL_KERNEL(sqr);
|
||||
|
||||
#undef GGML_METAL_DEL_KERNEL
|
||||
|
||||
@ -766,6 +772,43 @@ void ggml_metal_graph_compute(
|
||||
{
|
||||
// noop
|
||||
} break;
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
|
||||
int64_t nb = ne00;
|
||||
[encoder setComputePipelineState:ctx->pipeline_concat];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
@ -903,6 +946,17 @@ void ggml_metal_graph_compute(
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SQR:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_sqr];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
const int nth = MIN(32, ne00);
|
||||
|
@ -132,6 +132,13 @@ kernel void kernel_relu(
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sqr(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src0[tpig];
|
||||
}
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
@ -1098,6 +1105,62 @@ kernel void kernel_cpy_f32_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_concat(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i03 = tgpig.z;
|
||||
const int64_t i02 = tgpig.y;
|
||||
const int64_t i01 = tgpig.x;
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
|
||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||
if (i02 < ne02) {
|
||||
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
|
||||
src0_ptr += ntg.x*nb00;
|
||||
} else {
|
||||
((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
|
||||
src1_ptr += ntg.x*nb10;
|
||||
}
|
||||
dst_ptr += ntg.x*nb0;
|
||||
}
|
||||
}
|
||||
|
||||
//============================================ k-quants ======================================================
|
||||
|
||||
#ifndef QK_K
|
||||
|
@ -85,6 +85,7 @@ class MODEL_ARCH(IntEnum):
|
||||
GPTNEOX : int = auto()
|
||||
MPT : int = auto()
|
||||
STARCODER : int = auto()
|
||||
PERSIMMON : int = auto()
|
||||
REFACT : int = auto()
|
||||
BERT : int = auto()
|
||||
|
||||
@ -108,6 +109,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
FFN_DOWN : int = auto()
|
||||
FFN_UP : int = auto()
|
||||
FFN_NORM : int = auto()
|
||||
ATTN_Q_NORM : int = auto()
|
||||
ATTN_K_NORM : int = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
@ -119,6 +122,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.GPTNEOX: "gptneox",
|
||||
MODEL_ARCH.MPT: "mpt",
|
||||
MODEL_ARCH.STARCODER: "starcoder",
|
||||
MODEL_ARCH.PERSIMMON: "persimmon",
|
||||
MODEL_ARCH.REFACT: "refact",
|
||||
MODEL_ARCH.BERT: "bert",
|
||||
}
|
||||
@ -130,7 +134,6 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
|
||||
MODEL_TENSOR.OUTPUT: "output",
|
||||
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
|
||||
|
||||
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
|
||||
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
|
||||
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
|
||||
@ -139,6 +142,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
||||
@ -249,6 +254,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.PERSIMMON: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.REFACT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
@ -279,6 +298,9 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.PERSIMMON: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@ -286,12 +308,13 @@ class TensorNameMap:
|
||||
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
|
||||
# Token embeddings
|
||||
MODEL_TENSOR.TOKEN_EMBD: (
|
||||
"gpt_neox.embed_in", # gptneox
|
||||
"transformer.wte", # gpt2 gpt-j mpt refact
|
||||
"transformer.word_embeddings", # falcon
|
||||
"model.embed_tokens", # llama-hf
|
||||
"tok_embeddings", # llama-pth
|
||||
"embeddings.word_embeddings", # bert
|
||||
"gpt_neox.embed_in", # gptneox
|
||||
"transformer.wte", # gpt2 gpt-j mpt refact
|
||||
"transformer.word_embeddings", # falcon
|
||||
"model.embed_tokens", # llama-hf
|
||||
"tok_embeddings", # llama-pth
|
||||
"embeddings.word_embeddings", # bert
|
||||
"language_model.embedding.word_embeddings", # persimmon
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@ -307,20 +330,22 @@ class TensorNameMap:
|
||||
|
||||
# Output
|
||||
MODEL_TENSOR.OUTPUT: (
|
||||
"embed_out", # gptneox
|
||||
"lm_head", # gpt2 gpt-j mpt falcon llama-hf baichuan
|
||||
"output", # llama-pth
|
||||
"embed_out", # gptneox
|
||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan
|
||||
"output", # llama-pth
|
||||
"word_embeddings_for_head", # persimmon
|
||||
),
|
||||
|
||||
# Output norm
|
||||
MODEL_TENSOR.OUTPUT_NORM: (
|
||||
"gpt_neox.final_layer_norm", # gptneox
|
||||
"transformer.ln_f", # gpt2 gpt-j falcon
|
||||
"model.norm", # llama-hf baichuan
|
||||
"norm", # llama-pth
|
||||
"embeddings.LayerNorm", # bert
|
||||
"transformer.norm_f", # mpt
|
||||
"ln_f", # refact
|
||||
"gpt_neox.final_layer_norm", # gptneox
|
||||
"transformer.ln_f", # gpt2 gpt-j falcon
|
||||
"model.norm", # llama-hf baichuan
|
||||
"norm", # llama-pth
|
||||
"embeddings.LayerNorm", # bert
|
||||
"transformer.norm_f", # mpt
|
||||
"ln_f", # refact
|
||||
"language_model.encoder.final_layernorm", # persimmon
|
||||
),
|
||||
|
||||
# Rope frequencies
|
||||
@ -332,14 +357,15 @@ class TensorNameMap:
|
||||
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
|
||||
# Attention norm
|
||||
MODEL_TENSOR.ATTN_NORM: (
|
||||
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
|
||||
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact
|
||||
"transformer.blocks.{bid}.norm_1", # mpt
|
||||
"transformer.h.{bid}.input_layernorm", # falcon7b
|
||||
"transformer.h.{bid}.ln_mlp", # falcon40b
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf
|
||||
"layers.{bid}.attention_norm", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
|
||||
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact
|
||||
"transformer.blocks.{bid}.norm_1", # mpt
|
||||
"transformer.h.{bid}.input_layernorm", # falcon7b
|
||||
"transformer.h.{bid}.ln_mlp", # falcon40b
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf
|
||||
"layers.{bid}.attention_norm", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
@ -349,10 +375,11 @@ class TensorNameMap:
|
||||
|
||||
# Attention query-key-value
|
||||
MODEL_TENSOR.ATTN_QKV: (
|
||||
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
|
||||
"transformer.h.{bid}.attn.c_attn", # gpt2
|
||||
"transformer.blocks.{bid}.attn.Wqkv", # mpt
|
||||
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
||||
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
|
||||
"transformer.h.{bid}.attn.c_attn", # gpt2
|
||||
"transformer.blocks.{bid}.attn.Wqkv", # mpt
|
||||
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
||||
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
||||
),
|
||||
|
||||
# Attention query
|
||||
@ -381,14 +408,15 @@ class TensorNameMap:
|
||||
|
||||
# Attention output
|
||||
MODEL_TENSOR.ATTN_OUT: (
|
||||
"gpt_neox.layers.{bid}.attention.dense", # gptneox
|
||||
"transformer.h.{bid}.attn.c_proj", # gpt2 refact
|
||||
"transformer.blocks.{bid}.attn.out_proj", # mpt
|
||||
"transformer.h.{bid}.self_attention.dense", # falcon
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf
|
||||
"layers.{bid}.attention.wo", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||
"gpt_neox.layers.{bid}.attention.dense", # gptneox
|
||||
"transformer.h.{bid}.attn.c_proj", # gpt2 refact
|
||||
"transformer.blocks.{bid}.attn.out_proj", # mpt
|
||||
"transformer.h.{bid}.self_attention.dense", # falcon
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf
|
||||
"layers.{bid}.attention.wo", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.self_attention.dense" # persimmon
|
||||
),
|
||||
|
||||
# Rotary embeddings
|
||||
@ -399,24 +427,26 @@ class TensorNameMap:
|
||||
|
||||
# Feed-forward norm
|
||||
MODEL_TENSOR.FFN_NORM: (
|
||||
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
|
||||
"transformer.h.{bid}.ln_2", # gpt2 refact
|
||||
"transformer.blocks.{bid}.norm_2", # mpt
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf
|
||||
"layers.{bid}.ffn_norm", # llama-pth
|
||||
"encoder.layer.{bid}.output.LayerNorm", # bert
|
||||
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
|
||||
"transformer.h.{bid}.ln_2", # gpt2 refact
|
||||
"transformer.blocks.{bid}.norm_2", # mpt
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf
|
||||
"layers.{bid}.ffn_norm", # llama-pth
|
||||
"encoder.layer.{bid}.output.LayerNorm", # bert
|
||||
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
|
||||
),
|
||||
|
||||
# Feed-forward up
|
||||
MODEL_TENSOR.FFN_UP: (
|
||||
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
|
||||
"transformer.h.{bid}.mlp.c_fc", # gpt2
|
||||
"transformer.blocks.{bid}.ffn.up_proj", # mpt
|
||||
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
|
||||
"model.layers.{bid}.mlp.up_proj", # llama-hf refact
|
||||
"layers.{bid}.feed_forward.w3", # llama-pth
|
||||
"encoder.layer.{bid}.intermediate.dense", # bert
|
||||
"transformer.h.{bid}.mlp.fc_in", # gpt-j
|
||||
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
|
||||
"transformer.h.{bid}.mlp.c_fc", # gpt2
|
||||
"transformer.blocks.{bid}.ffn.up_proj", # mpt
|
||||
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
|
||||
"model.layers.{bid}.mlp.up_proj", # llama-hf refact
|
||||
"layers.{bid}.feed_forward.w3", # llama-pth
|
||||
"encoder.layer.{bid}.intermediate.dense", # bert
|
||||
"transformer.h.{bid}.mlp.fc_in", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
||||
),
|
||||
|
||||
# Feed-forward gate
|
||||
@ -427,15 +457,28 @@ class TensorNameMap:
|
||||
|
||||
# Feed-forward down
|
||||
MODEL_TENSOR.FFN_DOWN: (
|
||||
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
|
||||
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact
|
||||
"transformer.blocks.{bid}.ffn.down_proj", # mpt
|
||||
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
|
||||
"model.layers.{bid}.mlp.down_proj", # llama-hf
|
||||
"layers.{bid}.feed_forward.w2", # llama-pth
|
||||
"encoder.layer.{bid}.output.dense", # bert
|
||||
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
||||
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
|
||||
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact
|
||||
"transformer.blocks.{bid}.ffn.down_proj", # mpt
|
||||
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
|
||||
"model.layers.{bid}.mlp.down_proj", # llama-hf
|
||||
"layers.{bid}.feed_forward.w2", # llama-pth
|
||||
"encoder.layer.{bid}.output.dense", # bert
|
||||
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_K_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ROPE_FREQS: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
|
||||
)
|
||||
}
|
||||
|
||||
mapping: dict[str, tuple[MODEL_TENSOR, str]]
|
||||
|
522
llama.cpp
522
llama.cpp
@ -186,6 +186,7 @@ enum llm_arch {
|
||||
LLM_ARCH_GPTNEOX,
|
||||
LLM_ARCH_MPT,
|
||||
LLM_ARCH_STARCODER,
|
||||
LLM_ARCH_PERSIMMON,
|
||||
LLM_ARCH_REFACT,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
@ -199,6 +200,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_MPT, "mpt" },
|
||||
{ LLM_ARCH_BAICHUAN, "baichuan" },
|
||||
{ LLM_ARCH_STARCODER, "starcoder" },
|
||||
{ LLM_ARCH_PERSIMMON, "persimmon" },
|
||||
{ LLM_ARCH_REFACT, "refact" },
|
||||
};
|
||||
|
||||
@ -318,6 +320,8 @@ enum llm_tensor {
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
};
|
||||
|
||||
static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
||||
@ -399,6 +403,23 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PERSIMMON,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd"},
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm"},
|
||||
{ LLM_TENSOR_OUTPUT, "output"},
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm"},
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv"},
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output"},
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"},
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"},
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm"},
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down"},
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up"},
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd"},
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MPT,
|
||||
{
|
||||
@ -959,6 +980,7 @@ enum e_model {
|
||||
MODEL_1B,
|
||||
MODEL_3B,
|
||||
MODEL_7B,
|
||||
MODEL_8B,
|
||||
MODEL_13B,
|
||||
MODEL_15B,
|
||||
MODEL_30B,
|
||||
@ -1041,6 +1063,10 @@ struct llama_layer {
|
||||
struct ggml_tensor * attn_norm_b;
|
||||
struct ggml_tensor * attn_norm_2;
|
||||
struct ggml_tensor * attn_norm_2_b;
|
||||
struct ggml_tensor * attn_q_norm;
|
||||
struct ggml_tensor * attn_q_norm_b;
|
||||
struct ggml_tensor * attn_k_norm;
|
||||
struct ggml_tensor * attn_k_norm_b;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * wq;
|
||||
@ -1901,6 +1927,7 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_1B: return "1B";
|
||||
case MODEL_3B: return "3B";
|
||||
case MODEL_7B: return "7B";
|
||||
case MODEL_8B: return "8B";
|
||||
case MODEL_13B: return "13B";
|
||||
case MODEL_15B: return "15B";
|
||||
case MODEL_30B: return "30B";
|
||||
@ -2013,6 +2040,14 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PERSIMMON:
|
||||
{
|
||||
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
|
||||
switch (hparams.n_layer) {
|
||||
case 36: model.type = e_model::MODEL_8B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
}
|
||||
case LLM_ARCH_REFACT:
|
||||
{
|
||||
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
|
||||
@ -2549,6 +2584,67 @@ static void llm_load_tensors(
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PERSIMMON:
|
||||
{
|
||||
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
||||
|
||||
{
|
||||
ggml_backend backend_norm;
|
||||
ggml_backend backend_output;
|
||||
|
||||
if (n_gpu_layers > int(n_layer)) {
|
||||
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
|
||||
// on Windows however this is detrimental unless everything is on the GPU
|
||||
#ifndef _WIN32
|
||||
backend_norm = LLAMA_BACKEND_OFFLOAD;
|
||||
#else
|
||||
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
|
||||
#endif // _WIN32
|
||||
|
||||
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
|
||||
} else {
|
||||
backend_norm = GGML_BACKEND_CPU;
|
||||
backend_output = GGML_BACKEND_CPU;
|
||||
}
|
||||
|
||||
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
|
||||
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
|
||||
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
|
||||
|
||||
if (backend_norm == GGML_BACKEND_GPU) {
|
||||
vram_weights += ggml_nbytes(model.output_norm);
|
||||
vram_weights += ggml_nbytes(model.output_norm_b);
|
||||
}
|
||||
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
|
||||
vram_weights += ggml_nbytes(model.output);
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t n_ff = hparams.n_ff;
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
model.layers.resize(n_layer);
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
|
||||
const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT;
|
||||
auto & layer = model.layers[i];
|
||||
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
||||
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
|
||||
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
|
||||
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split);
|
||||
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
|
||||
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
|
||||
layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split);
|
||||
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
||||
layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split);
|
||||
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
|
||||
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);
|
||||
layer.attn_q_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}, backend);
|
||||
layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend);
|
||||
layer.attn_k_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}, backend);
|
||||
layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@ -2658,8 +2754,8 @@ static bool llama_model_load(
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llm_build_llama(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch) {
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
@ -2697,11 +2793,9 @@ static struct ggml_cgraph * llm_build_llama(
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.data,
|
||||
/*.no_alloc =*/ false,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
params.no_alloc = true;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
@ -3085,11 +3179,9 @@ static struct ggml_cgraph * llm_build_baichaun(
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.data,
|
||||
/*.no_alloc =*/ false,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
params.no_alloc = true;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
@ -3486,11 +3578,9 @@ static struct ggml_cgraph * llm_build_refact(
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.data,
|
||||
/*.no_alloc =*/ false,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
params.no_alloc = true;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
@ -3840,11 +3930,9 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.data,
|
||||
/*.no_alloc =*/ false,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
params.no_alloc = true;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
@ -4200,11 +4288,9 @@ static struct ggml_cgraph * llm_build_starcoder(
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.data,
|
||||
/*.no_alloc =*/ false,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
params.no_alloc = true;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
@ -4415,6 +4501,404 @@ static struct ggml_cgraph * llm_build_starcoder(
|
||||
return gf;
|
||||
}
|
||||
|
||||
|
||||
static struct ggml_cgraph * llm_build_persimmon(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const auto & kv_self = lctx.kv_self;
|
||||
|
||||
GGML_ASSERT(!!kv_self.ctx);
|
||||
|
||||
const auto & cparams = lctx.cparams;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_ctx = cparams.n_ctx;
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
const size_t n_rot = n_embd_head / 2;
|
||||
|
||||
const float freq_base = cparams.rope_freq_base;
|
||||
const float freq_scale = cparams.rope_freq_scale;
|
||||
const float norm_eps = hparams.f_norm_eps;
|
||||
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
||||
|
||||
const int32_t n_tokens = batch.n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
|
||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||
|
||||
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
||||
|
||||
auto & buf_compute = lctx.buf_compute;
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.data,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
if (batch.token) {
|
||||
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
|
||||
ggml_allocr_alloc(lctx.alloc, inp_tokens);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
|
||||
}
|
||||
ggml_set_name(inp_tokens, "inp_tokens");
|
||||
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
|
||||
} else {
|
||||
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
|
||||
ggml_allocr_alloc(lctx.alloc, inpL);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
|
||||
}
|
||||
}
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
(void) i_gpu_start;
|
||||
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
|
||||
offload_func_t offload_func_kq = llama_nop;
|
||||
offload_func_t offload_func_v = llama_nop;
|
||||
// KQ_scale
|
||||
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||
ggml_allocr_alloc(lctx.alloc, KQ_scale);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));
|
||||
}
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||
offload_func_kq(KQ_mask);
|
||||
ggml_set_name(KQ_mask, "KQ_mask");
|
||||
ggml_allocr_alloc(lctx.alloc, KQ_mask);
|
||||
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
float * data = (float *) KQ_mask->data;
|
||||
memset(data, 0, ggml_nbytes(KQ_mask));
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const llama_pos pos = batch.pos[j];
|
||||
const llama_seq_id seq_id = batch.seq_id[j];
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
offload_func_kq(KQ_pos);
|
||||
ggml_set_name(KQ_pos, "KQ_pos");
|
||||
ggml_allocr_alloc(lctx.alloc, KQ_pos);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
int * data = (int *) KQ_pos->data;
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
data[i] = batch.pos[i];
|
||||
}
|
||||
}
|
||||
if (do_rope_shift) {
|
||||
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
||||
offload_func_kq(K_shift);
|
||||
ggml_set_name(K_shift, "K_shift");
|
||||
ggml_allocr_alloc(lctx.alloc, K_shift);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
int * data = (int *) K_shift->data;
|
||||
for (int i = 0; i < n_ctx; ++i) {
|
||||
data[i] = kv_self.cells[i].delta;
|
||||
}
|
||||
}
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * tmp =
|
||||
// we rotate only the first n_rot dimensions.
|
||||
ggml_rope_custom_inplace(ctx0,
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_rot, n_head, n_ctx,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il)
|
||||
),
|
||||
K_shift, n_rot, 2, 0, freq_base, freq_scale);
|
||||
offload_func_kq(tmp);
|
||||
ggml_build_forward_expand(gf, tmp);
|
||||
}
|
||||
}
|
||||
for (int il=0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * residual = inpL;
|
||||
offload_func_t offload_func = llama_nop;
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpL, norm_eps);
|
||||
offload_func(cur);
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
|
||||
offload_func(cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b);
|
||||
offload_func(cur);
|
||||
ggml_format_name(cur, "input_layernorm_%d", il);
|
||||
}
|
||||
// self attention
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
offload_func_kq(cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
offload_func_kq(cur);
|
||||
|
||||
// split qkv
|
||||
GGML_ASSERT(n_head_kv == n_head);
|
||||
ggml_set_name(cur, format("qkv_%d", il).c_str());
|
||||
struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens);
|
||||
offload_func_kq(tmpqkv);
|
||||
struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
|
||||
offload_func_kq(tmpqkv_perm);
|
||||
ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il);
|
||||
struct ggml_tensor * tmpq = ggml_view_3d(
|
||||
ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
|
||||
0
|
||||
);
|
||||
offload_func_kq(tmpq);
|
||||
struct ggml_tensor * tmpk = ggml_view_3d(
|
||||
ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens
|
||||
);
|
||||
offload_func_kq(tmpk);
|
||||
// Q/K Layernorm
|
||||
tmpq = ggml_norm(ctx0, tmpq, norm_eps);
|
||||
offload_func_kq(tmpq);
|
||||
tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
|
||||
offload_func_kq(tmpq);
|
||||
tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);
|
||||
offload_func_kq(tmpq);
|
||||
|
||||
tmpk = ggml_norm(ctx0, tmpk, norm_eps);
|
||||
offload_func_v(tmpk);
|
||||
tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm);
|
||||
offload_func_v(tmpk);
|
||||
tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b);
|
||||
offload_func_v(tmpk);
|
||||
|
||||
// RoPE the first n_rot of q/k, pass the other half, and concat.
|
||||
struct ggml_tensor * qrot = ggml_view_3d(
|
||||
ctx0, tmpq, n_rot, n_head, n_tokens,
|
||||
ggml_element_size(tmpq) * n_embd_head,
|
||||
ggml_element_size(tmpq) * n_embd_head * n_head,
|
||||
0
|
||||
);
|
||||
offload_func_kq(qrot);
|
||||
ggml_format_name(qrot, "qrot_%d", il);
|
||||
struct ggml_tensor * krot = ggml_view_3d(
|
||||
ctx0, tmpk, n_rot, n_head, n_tokens,
|
||||
ggml_element_size(tmpk) * n_embd_head,
|
||||
ggml_element_size(tmpk) * n_embd_head * n_head,
|
||||
0
|
||||
);
|
||||
offload_func_kq(krot);
|
||||
ggml_format_name(krot, "krot_%d", il);
|
||||
|
||||
// get the second half of tmpq, e.g tmpq[n_rot:, :, :]
|
||||
struct ggml_tensor * qpass = ggml_view_3d(
|
||||
ctx0, tmpq, n_rot, n_head, n_tokens,
|
||||
ggml_element_size(tmpq) * n_embd_head,
|
||||
ggml_element_size(tmpq) * n_embd_head * n_head,
|
||||
ggml_element_size(tmpq) * n_rot
|
||||
);
|
||||
offload_func_kq(qpass);
|
||||
ggml_format_name(qpass, "qpass_%d", il);
|
||||
struct ggml_tensor * kpass = ggml_view_3d(
|
||||
ctx0, tmpk, n_rot, n_head, n_tokens,
|
||||
ggml_element_size(tmpk) * n_embd_head,
|
||||
ggml_element_size(tmpk) * n_embd_head * n_head,
|
||||
ggml_element_size(tmpk) * n_rot
|
||||
);
|
||||
offload_func_kq(kpass);
|
||||
ggml_format_name(kpass, "kpass_%d", il);
|
||||
|
||||
struct ggml_tensor * qrotated = ggml_rope_custom(
|
||||
ctx0, qrot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale
|
||||
);
|
||||
offload_func_kq(qrotated);
|
||||
struct ggml_tensor * krotated = ggml_rope_custom(
|
||||
ctx0, krot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale
|
||||
);
|
||||
offload_func_kq(krotated);
|
||||
// ggml currently only supports concatenation on dim=2
|
||||
// so we need to permute qrot, qpass, concat, then permute back.
|
||||
qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
|
||||
offload_func_kq(qrotated);
|
||||
krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
|
||||
offload_func_kq(krotated);
|
||||
|
||||
qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
|
||||
offload_func_kq(qpass);
|
||||
kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
|
||||
offload_func_kq(kpass);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
|
||||
offload_func_kq(Qcur);
|
||||
struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
|
||||
offload_func_kq(Kcur);
|
||||
|
||||
struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3));
|
||||
offload_func_kq(Q);
|
||||
|
||||
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
|
||||
offload_func_kq(Kcur);
|
||||
{
|
||||
struct ggml_tensor * tmpv = ggml_view_3d(
|
||||
ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2
|
||||
);
|
||||
offload_func_v(tmpv);
|
||||
// store K, V in cache
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
|
||||
offload_func_v(Vcur);
|
||||
ggml_set_name(Vcur, "Vcur");
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(
|
||||
ctx0, kv_self.k, n_tokens*n_embd_gqa,
|
||||
(ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)
|
||||
);
|
||||
offload_func_kq(k);
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
offload_func_v(v);
|
||||
ggml_set_name(v, "v");
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k,
|
||||
n_embd_head, n_kv, n_head_kv,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
|
||||
|
||||
offload_func_kq(K);
|
||||
ggml_format_name(K, "K_%d", il);
|
||||
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
offload_func_kq(KQ);
|
||||
ggml_set_name(KQ, "KQ");
|
||||
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
|
||||
offload_func_kq(KQ_scaled);
|
||||
ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
|
||||
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
|
||||
offload_func_kq(KQ_masked);
|
||||
ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
offload_func_kq(KQ_soft_max);
|
||||
ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, kv_self.v,
|
||||
n_kv, n_embd_head, n_head_kv,
|
||||
ggml_element_size(kv_self.v)*n_ctx,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
|
||||
offload_func_v(V);
|
||||
ggml_set_name(V, "V");
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
offload_func_v(KQV);
|
||||
ggml_set_name(KQV, "KQV");
|
||||
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
offload_func_v(KQV_merged);
|
||||
ggml_set_name(KQV_merged, "KQV_merged");
|
||||
|
||||
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
|
||||
offload_func_v(cur);
|
||||
ggml_set_name(cur, "KQV_merged_contiguous");
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
|
||||
offload_func(cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bo);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_wo");
|
||||
}
|
||||
|
||||
struct ggml_tensor * inpFF = ggml_add(ctx0, residual, cur);
|
||||
offload_func(inpFF);
|
||||
ggml_set_name(inpFF, "inpFF");
|
||||
{
|
||||
// MLP
|
||||
{
|
||||
// Norm
|
||||
cur = ggml_norm(ctx0, inpFF, norm_eps);
|
||||
offload_func(cur);
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0, cur, model.layers[il].ffn_norm),
|
||||
model.layers[il].ffn_norm_b
|
||||
);
|
||||
ggml_set_name(cur, "ffn_norm");
|
||||
offload_func(cur);
|
||||
}
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
|
||||
offload_func(cur);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].b3);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_ffn_up");
|
||||
|
||||
cur = ggml_sqr(ctx0, ggml_relu(ctx0, cur));
|
||||
ggml_set_name(cur, "result_ffn_act");
|
||||
offload_func(cur);
|
||||
offload_func(cur->src[0]);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
|
||||
offload_func(cur);
|
||||
cur = ggml_add(ctx0,
|
||||
cur,
|
||||
model.layers[il].b2);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "outFF");
|
||||
}
|
||||
cur = ggml_add(ctx0, cur, inpFF);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "inpFF_+_outFF");
|
||||
inpL = cur;
|
||||
}
|
||||
cur = inpL;
|
||||
{
|
||||
cur = ggml_norm(ctx0, cur, norm_eps);
|
||||
offload_func_nr(cur);
|
||||
cur = ggml_mul(ctx0, cur, model.output_norm);
|
||||
offload_func_nr(cur);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.output_norm_b);
|
||||
// offload_func_nr(cur);
|
||||
|
||||
ggml_set_name(cur, "result_norm");
|
||||
}
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
ggml_set_name(cur, "result_output");
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
ggml_free(ctx0);
|
||||
return gf;
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch) {
|
||||
@ -4439,6 +4923,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm_build_starcoder(lctx, batch);
|
||||
} break;
|
||||
case LLM_ARCH_PERSIMMON:
|
||||
{
|
||||
result = llm_build_persimmon(lctx, batch);
|
||||
}
|
||||
case LLM_ARCH_REFACT:
|
||||
{
|
||||
result = llm_build_refact(lctx, batch);
|
||||
|
Loading…
Reference in New Issue
Block a user