mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-23 21:17:54 +01:00
llm : add MPT support (#3417)
* CUDA: added support for ggml_clamp (see also: https://github.com/ggerganov/ggml/issues/545) * mpt : added an implementation based (mostly) on falcon integration, modified with deltas from ggml/examples/mpt * mpt : protect against "clip_qkv": null in mpt-7b * mpt : quick fix to avoid "Strange model" warning when quantizing MPT models * mpt : addendum to changeset:84e30e8 - leave parameter clamp_kqv out from metadata rather than use 0.0 to indicate "no clamping" (more compliant with the current GGUF spec?) * mpt : standardized all tensor names to follow GGUF spec * mpt : addendum to changeset:1be89c40 - use "req" parameter of GGUF_GET_KEY macro instead of duplicate code * mpt : fixed comment s/gptneox/mpt/ * mpt : remove tabs, trailing whitespace * mpt : removed ne01 + n_past == ne00 assertion from alibi (cuda/f32) and rope_shift from build_mpt * mpt : updated convert-mpt-hf-to-gguf.py to reflect changes made to convert-gptneox-hf-to-gguf.py in pr:3252 * comment out n_past instead of marking it unused * mpt : removed hardcoded +178 from convert script in favor of utilizing hparams["vocab_size"] * mpt : remove unused tokenizer_json in convert script * ggml : remove obsolete n_past assert in ggml_alibi * llama : print clam_kqv and max_alibi_bias hparams --------- Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
11ea5c7d96
commit
f5f9121de1
216
convert-mpt-hf-to-gguf.py
Executable file
216
convert-mpt-hf-to-gguf.py
Executable file
@ -0,0 +1,216 @@
|
||||
#!/usr/bin/env python3
|
||||
# HF mpt--> gguf conversion
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer # type: ignore[import]
|
||||
|
||||
if 'NO_LOCAL_GGUF' not in os.environ:
|
||||
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
|
||||
import gguf
|
||||
|
||||
|
||||
def count_model_parts(dir_model: Path) -> int:
|
||||
num_parts = 0
|
||||
for filename in os.listdir(dir_model):
|
||||
if filename.startswith("pytorch_model-"):
|
||||
num_parts += 1
|
||||
|
||||
if num_parts > 0:
|
||||
print("gguf: found " + str(num_parts) + " model parts")
|
||||
return num_parts
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file")
|
||||
parser.add_argument(
|
||||
"--vocab-only", action="store_true",
|
||||
help="extract only the vocab",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outfile", type=Path,
|
||||
help="path to write to; default: based on input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"model", type=Path,
|
||||
help="directory containing model file, or model file itself (*.bin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"ftype", type=int, choices=[0, 1], default=1, nargs='?',
|
||||
help="output format - use 0 for float32, 1 for float16",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
args = parse_args()
|
||||
|
||||
dir_model = args.model
|
||||
ftype = args.ftype
|
||||
if not dir_model.is_dir():
|
||||
print(f'Error: {args.model} is not a directory', file = sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# possible tensor data types
|
||||
# ftype == 0 -> float32
|
||||
# ftype == 1 -> float16
|
||||
|
||||
# map from ftype to string
|
||||
ftype_str = ["f32", "f16"]
|
||||
|
||||
if args.outfile is not None:
|
||||
fname_out = args.outfile
|
||||
else:
|
||||
# output in the same directory as the model by default
|
||||
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
|
||||
|
||||
print("gguf: loading model "+dir_model.name)
|
||||
|
||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
hparams = json.load(f)
|
||||
|
||||
if hparams["architectures"][0] != "MPTForCausalLM":
|
||||
print("Model architecture not supported: " + hparams["architectures"][0])
|
||||
|
||||
sys.exit()
|
||||
|
||||
# get number of model parts
|
||||
num_parts = count_model_parts(dir_model)
|
||||
|
||||
ARCH=gguf.MODEL_ARCH.MPT
|
||||
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
|
||||
|
||||
print("gguf: get model metadata")
|
||||
|
||||
block_count = hparams["n_layers"]
|
||||
|
||||
gguf_writer.add_name(dir_model.name)
|
||||
gguf_writer.add_context_length(hparams["max_seq_len"])
|
||||
gguf_writer.add_embedding_length(hparams["d_model"])
|
||||
gguf_writer.add_block_count(block_count)
|
||||
gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
|
||||
gguf_writer.add_head_count(hparams["n_heads"])
|
||||
gguf_writer.add_layer_norm_eps(1e-05)
|
||||
if hparams["attn_config"]["clip_qkv"] is not None:
|
||||
gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
|
||||
gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])
|
||||
|
||||
# TOKENIZATION
|
||||
|
||||
print("gguf: get tokenizer metadata")
|
||||
|
||||
tokens: list[bytearray] = []
|
||||
scores: list[float] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
# gpt2 tokenizer
|
||||
gguf_writer.add_tokenizer_model("gpt2")
|
||||
|
||||
print("gguf: get gpt2 tokenizer vocab")
|
||||
|
||||
# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), but
|
||||
# there are only 50254 (len(tokenizer.vocab)) tokens in the vocab, presumably to
|
||||
# accomodate some "reserved" tokens; this is causing problems down the line in
|
||||
# llama.cpp, so we pad the vocab with dummy tokens:
|
||||
|
||||
vocab_size = hparams["vocab_size"]
|
||||
|
||||
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model)
|
||||
|
||||
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
|
||||
|
||||
for i in range(vocab_size):
|
||||
tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
|
||||
scores.append(0.0) # dummy
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
gguf_writer.add_token_list(tokens)
|
||||
gguf_writer.add_token_scores(scores)
|
||||
gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
|
||||
special_vocab.add_to_gguf(gguf_writer)
|
||||
|
||||
# TENSORS
|
||||
|
||||
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
|
||||
|
||||
# tensor info
|
||||
print("gguf: get tensor metadata")
|
||||
|
||||
if num_parts == 0:
|
||||
part_names = iter(("pytorch_model.bin",))
|
||||
else:
|
||||
part_names = (
|
||||
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
|
||||
)
|
||||
|
||||
for part_name in part_names:
|
||||
if args.vocab_only:
|
||||
break
|
||||
print("gguf: loading model part '" + part_name + "'")
|
||||
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
|
||||
|
||||
for name in model_part.keys():
|
||||
data = model_part[name]
|
||||
|
||||
old_dtype = data.dtype
|
||||
|
||||
# convert any unsupported data types to float32
|
||||
if data.dtype != torch.float16 and data.dtype != torch.float32:
|
||||
data = data.to(torch.float32)
|
||||
|
||||
data = data.squeeze().numpy()
|
||||
|
||||
# map tensor names
|
||||
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
|
||||
if new_name is None:
|
||||
print("Cannot map tensor '" + name + "'")
|
||||
continue # for the sake of compatibility with some old published models, don't quit
|
||||
sys.exit()
|
||||
|
||||
n_dims = len(data.shape)
|
||||
data_dtype = data.dtype
|
||||
|
||||
# if f32 desired, convert any float16 to float32
|
||||
if ftype == 0 and data_dtype == np.float16:
|
||||
data = data.astype(np.float32)
|
||||
|
||||
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
|
||||
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
|
||||
data = data.astype(np.float32)
|
||||
|
||||
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
||||
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
||||
data = data.astype(np.float16)
|
||||
|
||||
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
|
||||
|
||||
gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
# note: MPT output is tied to (same as) wte in original model;
|
||||
# for easier implementation in llama.cpp it's duplicated in GGUF, though :/
|
||||
if new_name == "token_embd.weight":
|
||||
gguf_writer.add_tensor("output.weight", data)
|
||||
|
||||
print("gguf: write header")
|
||||
gguf_writer.write_header_to_file()
|
||||
print("gguf: write metadata")
|
||||
gguf_writer.write_kv_data_to_file()
|
||||
if not args.vocab_only:
|
||||
print("gguf: write tensors")
|
||||
gguf_writer.write_tensors_to_file()
|
||||
|
||||
gguf_writer.close()
|
||||
|
||||
print(f"gguf: model successfully exported to '{fname_out}'")
|
||||
print("")
|
47
ggml-cuda.cu
47
ggml-cuda.cu
@ -415,6 +415,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
||||
#define CUDA_SILU_BLOCK_SIZE 256
|
||||
#define CUDA_CPY_BLOCK_SIZE 32
|
||||
#define CUDA_SCALE_BLOCK_SIZE 256
|
||||
#define CUDA_CLAMP_BLOCK_SIZE 256
|
||||
#define CUDA_ROPE_BLOCK_SIZE 256
|
||||
#define CUDA_ALIBI_BLOCK_SIZE 32
|
||||
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
|
||||
@ -4585,6 +4586,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
|
||||
dst[i] = scale * x[i];
|
||||
}
|
||||
|
||||
static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
||||
}
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dq>
|
||||
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
|
||||
@ -5475,6 +5485,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
|
||||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
||||
}
|
||||
|
||||
static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
|
||||
clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
|
||||
@ -6419,12 +6434,12 @@ inline void ggml_cuda_op_alibi(
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
||||
float max_bias;
|
||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
GGML_ASSERT(ne01 + n_past == ne00);
|
||||
//GGML_ASSERT(ne01 + n_past == ne00);
|
||||
GGML_ASSERT(n_head == ne02);
|
||||
|
||||
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
||||
@ -6500,6 +6515,24 @@ inline void ggml_cuda_op_scale(
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_clamp(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float min = ((float *) dst->op_params)[0];
|
||||
const float max = ((float *) dst->op_params)[1];
|
||||
|
||||
clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
|
||||
const int64_t nrows0 = ggml_nrows(src0);
|
||||
|
||||
@ -7061,6 +7094,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
|
||||
}
|
||||
|
||||
static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
|
||||
}
|
||||
|
||||
static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
@ -7470,6 +7507,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_OP_SCALE:
|
||||
func = ggml_cuda_scale;
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_clamp;
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
func = ggml_cuda_cpy;
|
||||
break;
|
||||
|
@ -1299,7 +1299,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
const int nth = MIN(1024, ne00);
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
||||
float max_bias;
|
||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
||||
|
4
ggml.c
4
ggml.c
@ -13059,13 +13059,11 @@ static void ggml_compute_forward_alibi_f32(
|
||||
return;
|
||||
}
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
||||
float max_bias;
|
||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
||||
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
|
||||
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
|
||||
|
425
llama.cpp
425
llama.cpp
@ -424,6 +424,14 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
||||
LLM_ARCH_MPT,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -1011,6 +1019,9 @@ struct llama_hparams {
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
|
||||
float f_clamp_kqv;
|
||||
float f_max_alibi_bias;
|
||||
|
||||
bool operator!=(const llama_hparams & other) const {
|
||||
if (this->vocab_only != other.vocab_only) return true;
|
||||
if (this->n_vocab != other.n_vocab) return true;
|
||||
@ -2060,6 +2071,20 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MPT:
|
||||
{
|
||||
hparams.f_clamp_kqv = 0.0f;
|
||||
|
||||
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
|
||||
GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV));
|
||||
GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS));
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: model.type = e_model::MODEL_7B; break;
|
||||
case 48: model.type = e_model::MODEL_30B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: (void)0;
|
||||
}
|
||||
|
||||
@ -2204,6 +2229,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
|
||||
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
|
||||
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
|
||||
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
|
||||
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
|
||||
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
|
||||
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
|
||||
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
||||
@ -2649,6 +2676,73 @@ static void llm_load_tensors(
|
||||
layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MPT:
|
||||
{
|
||||
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
||||
|
||||
// output
|
||||
{
|
||||
ggml_backend_type backend_norm;
|
||||
ggml_backend_type backend_output;
|
||||
|
||||
if (n_gpu_layers > int(n_layer)) {
|
||||
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
|
||||
// on Windows however this is detrimental unless everything is on the GPU
|
||||
#ifndef _WIN32
|
||||
backend_norm = LLAMA_BACKEND_OFFLOAD;
|
||||
#else
|
||||
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
|
||||
#endif // _WIN32
|
||||
|
||||
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
|
||||
} else {
|
||||
backend_norm = GGML_BACKEND_CPU;
|
||||
backend_output = GGML_BACKEND_CPU;
|
||||
}
|
||||
|
||||
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
|
||||
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
|
||||
|
||||
if (backend_norm == GGML_BACKEND_GPU) {
|
||||
vram_weights += ggml_nbytes(model.output_norm);
|
||||
}
|
||||
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
|
||||
vram_weights += ggml_nbytes(model.output);
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t n_ff = hparams.n_ff;
|
||||
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
|
||||
model.layers.resize(n_layer);
|
||||
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
|
||||
const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
||||
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
|
||||
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
|
||||
|
||||
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
|
||||
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
||||
|
||||
if (backend == GGML_BACKEND_GPU) {
|
||||
vram_weights +=
|
||||
ggml_nbytes(layer.attn_norm) +
|
||||
ggml_nbytes(layer.wqkv) +
|
||||
ggml_nbytes(layer.wo) +
|
||||
ggml_nbytes(layer.ffn_norm) +
|
||||
ggml_nbytes(layer.w2) +
|
||||
ggml_nbytes(layer.w3);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@ -4505,7 +4599,6 @@ static struct ggml_cgraph * llm_build_starcoder(
|
||||
return gf;
|
||||
}
|
||||
|
||||
|
||||
static struct ggml_cgraph * llm_build_persimmon(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch) {
|
||||
@ -4903,6 +4996,326 @@ static struct ggml_cgraph * llm_build_persimmon(
|
||||
return gf;
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llm_build_mpt(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
const auto & kv_self = lctx.kv_self;
|
||||
|
||||
GGML_ASSERT(!!kv_self.ctx);
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_ctx = cparams.n_ctx;
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
const float norm_eps = hparams.f_norm_eps;
|
||||
const float clamp_kqv = hparams.f_clamp_kqv;
|
||||
const float max_alibi_bias = hparams.f_max_alibi_bias;
|
||||
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
||||
const int32_t n_tokens = batch.n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
|
||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||
|
||||
//printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n",
|
||||
// kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift);
|
||||
|
||||
auto & buf_compute = lctx.buf_compute;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.data,
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
params.no_alloc = true;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
//int warmup = 0;
|
||||
if (batch.token) {
|
||||
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
|
||||
ggml_allocr_alloc(lctx.alloc, inp_tokens);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
|
||||
//warmup = ((uint32_t*) inp_tokens->data)[0] == 0;
|
||||
}
|
||||
|
||||
ggml_set_name(inp_tokens, "inp_tokens");
|
||||
|
||||
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
|
||||
} else {
|
||||
#ifdef GGML_USE_MPI
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
#endif
|
||||
|
||||
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
|
||||
|
||||
ggml_allocr_alloc(lctx.alloc, inpL);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
|
||||
}
|
||||
}
|
||||
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
(void) i_gpu_start;
|
||||
|
||||
// offload functions set the tensor output backend to GPU
|
||||
// tensors are GPU-accelerated if any input or the output has been offloaded
|
||||
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
|
||||
offload_func_t offload_func_kq = llama_nop;
|
||||
offload_func_t offload_func_v = llama_nop;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (n_gpu_layers > n_layer) {
|
||||
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
if (n_gpu_layers > n_layer + 1) {
|
||||
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
if (n_gpu_layers > n_layer + 2) {
|
||||
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
// KQ_scale
|
||||
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||
ggml_allocr_alloc(lctx.alloc, KQ_scale);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
}
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||
offload_func_kq(KQ_mask);
|
||||
ggml_set_name(KQ_mask, "KQ_mask");
|
||||
ggml_allocr_alloc(lctx.alloc, KQ_mask);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
float * data = (float *) KQ_mask->data;
|
||||
memset(data, 0, ggml_nbytes(KQ_mask));
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const llama_pos pos = batch.pos[j];
|
||||
const llama_seq_id seq_id = batch.seq_id[j];
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * attn_norm;
|
||||
|
||||
offload_func_t offload_func = llama_nop;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (il >= i_gpu_start) {
|
||||
offload_func = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
// self-attention
|
||||
// TODO: refactor into common function (shared with LLaMA)
|
||||
{
|
||||
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
|
||||
offload_func(attn_norm);
|
||||
|
||||
attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm);
|
||||
offload_func(attn_norm);
|
||||
|
||||
if (1) {
|
||||
cur = attn_norm;
|
||||
}
|
||||
|
||||
// compute QKV
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
offload_func_kq(cur);
|
||||
|
||||
if (clamp_kqv > 0.0f) {
|
||||
cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv);
|
||||
offload_func_kq(cur);
|
||||
}
|
||||
|
||||
const size_t wsize = ggml_type_size(cur->type);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head, n_tokens,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
0);
|
||||
offload_func_kq(Qcur);
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, n_tokens,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * n_head);
|
||||
offload_func_kq(Kcur);
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, n_tokens,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * (n_head + n_head_kv));
|
||||
offload_func_kq(Kcur);
|
||||
|
||||
ggml_set_name(Qcur, "Qcur");
|
||||
ggml_set_name(Kcur, "Kcur");
|
||||
|
||||
{
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
|
||||
offload_func_v(Vcur);
|
||||
offload_func_v(Vcur->src[0]->src[0]);
|
||||
ggml_set_name(Vcur, "Vcur");
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
|
||||
offload_func_kq(k);
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
offload_func_v(v);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
offload_func_kq(Q);
|
||||
ggml_set_name(Q, "Q");
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_embd_head, n_kv, n_head_kv,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
|
||||
offload_func_kq(K);
|
||||
ggml_set_name(K, "K");
|
||||
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
offload_func_kq(KQ);
|
||||
ggml_set_name(KQ, "KQ");
|
||||
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
|
||||
offload_func_kq(KQ_scaled);
|
||||
ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
|
||||
// TODO: replace with ggml_add()
|
||||
struct ggml_tensor * KQ_scaled_alibi =
|
||||
ggml_alibi(ctx0, KQ_scaled, 0, n_head, max_alibi_bias);
|
||||
offload_func_kq(KQ_scaled_alibi);
|
||||
ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
|
||||
|
||||
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
|
||||
offload_func_kq(KQ_masked);
|
||||
ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
||||
offload_func_v(KQ_soft_max);
|
||||
ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, kv_self.v,
|
||||
n_kv, n_embd_head, n_head_kv,
|
||||
ggml_element_size(kv_self.v)*n_ctx,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
|
||||
offload_func_v(V);
|
||||
ggml_set_name(V, "V");
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
offload_func_v(KQV);
|
||||
ggml_set_name(KQV, "KQV");
|
||||
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
offload_func_v(KQV_merged);
|
||||
ggml_set_name(KQV_merged, "KQV_merged");
|
||||
|
||||
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
|
||||
offload_func_v(cur);
|
||||
ggml_set_name(cur, "KQV_merged_contiguous");
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_wo");
|
||||
}
|
||||
|
||||
// Add the input
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
offload_func(cur);
|
||||
|
||||
struct ggml_tensor * attn_out = cur;
|
||||
|
||||
// feed forward
|
||||
{
|
||||
// Norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, attn_out, norm_eps);
|
||||
offload_func(cur);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
|
||||
offload_func(cur);
|
||||
}
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
|
||||
offload_func(cur);
|
||||
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
offload_func(cur);
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
|
||||
offload_func(cur);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, attn_out);
|
||||
offload_func(cur);
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, cur, norm_eps);
|
||||
offload_func_nr(cur);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, model.output_norm);
|
||||
ggml_set_name(cur, "result_norm");
|
||||
}
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
ggml_set_name(cur, "result_output");
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch) {
|
||||
@ -4935,6 +5348,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm_build_refact(lctx, batch);
|
||||
} break;
|
||||
case LLM_ARCH_MPT:
|
||||
{
|
||||
result = llm_build_mpt(lctx, batch);
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
@ -5065,7 +5482,8 @@ static int llama_decode_internal(
|
||||
const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA ||
|
||||
model.arch == LLM_ARCH_BAICHUAN ||
|
||||
model.arch == LLM_ARCH_FALCON ||
|
||||
model.arch == LLM_ARCH_REFACT;
|
||||
model.arch == LLM_ARCH_REFACT ||
|
||||
model.arch == LLM_ARCH_MPT;
|
||||
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
|
||||
if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
|
||||
n_threads = 1;
|
||||
@ -7161,7 +7579,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
const std::string name = ggml_get_name(meta);
|
||||
|
||||
// TODO: avoid hardcoded tensor names - use the TN_* constants
|
||||
if (name.find("attn_v.weight") != std::string::npos) {
|
||||
if (name.find("attn_v.weight") != std::string::npos ||
|
||||
name.find("attn_qkv.weight") != std::string::npos) {
|
||||
++n_attention_wv;
|
||||
}
|
||||
else if (name.find("ffn_down.weight") != std::string::npos) {
|
||||
|
Loading…
Reference in New Issue
Block a user