#include "llama-impl.h"
|
|
|
|
#include "llama-chat.h"
|
|
#include "llama-mmap.h"
|
|
#include "llama-context.h"
|
|
#include "llama-vocab.h"
|
|
#include "llama-sampling.h"
|
|
#include "llama-kv-cache.h"
|
|
#include "llama-model-loader.h"
|
|
#include "llama-model.h"
|
|
|
|
#include "ggml.h"
|
|
#include "ggml-alloc.h"
|
|
#include "ggml-backend.h"
|
|
#include "ggml-cpp.h"
|
|
|
|
#include <algorithm>
|
|
#include <array>
|
|
#include <cassert>
|
|
#include <cfloat>
|
|
#include <cmath>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <cstdio>
|
|
#include <cstring>
|
|
#include <ctime>
|
|
#include <functional>
|
|
#include <initializer_list>
|
|
#include <map>
|
|
|
|
#if defined(_MSC_VER)
|
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
#endif
|
|
|
|

//
// tensor loading (TODO: add llama_tensor_loader?)
//

static int llama_get_device_count(const llama_model & model) {
    return (int) model.devices.size();
}

// checks if the weight tensor can be used with the specified buffer type and device
static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
    GGML_ASSERT(w != nullptr);

    if (op == GGML_OP_NONE) {
        return true;
    }

    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };
    if (!ctx_ptr) {
        throw std::runtime_error(format("failed to create ggml context"));
    }
    ggml_context * ctx = ctx_ptr.get();

    ggml_tensor * op_tensor = nullptr;

    switch (op) {
        case GGML_OP_GET_ROWS:
            {
                ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
                op_tensor = ggml_get_rows(ctx, w, b);
            } break;
        case GGML_OP_MUL_MAT:
            {
                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
                op_tensor = ggml_mul_mat(ctx, w, b);
            } break;
        case GGML_OP_MUL_MAT_ID:
            {
                int n_expert_used = hparams.n_expert_used;
                ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
                ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512);
                op_tensor = ggml_mul_mat_id(ctx, w, b, ids);
            } break;
        case GGML_OP_ADD:
            {
                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
                op_tensor = ggml_add(ctx, a, w);
            } break;
        case GGML_OP_MUL:
            {
                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
                op_tensor = ggml_mul(ctx, a, w);
            } break;
        case GGML_OP_DIV:
            {
                ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]);
                op_tensor = ggml_div(ctx, a, w);
            } break;
        case GGML_OP_ROPE:
            {
                int n_embd_head = hparams.n_embd_head_v;
                int n_head = hparams.n_head();
                ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512);
                ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
                op_tensor = ggml_rope_ext(
                    ctx, a, b, w,
                    0, 0, 0, 0, 0,
                    0, 0, 0, 0
                );

            } break;
        case GGML_OP_SSM_CONV:
            {
                // FIXME
                ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
                op_tensor = ggml_ssm_conv(ctx, conv_x, w);
            } break;
        case GGML_OP_SSM_SCAN:
            {
                // FIXME
                const int64_t d_state      = w->ne[0];
                const int64_t d_inner      = w->ne[1];
                const int64_t n_seq_tokens = 512;
                const int64_t n_seqs       = 1;
                ggml_tensor * s  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
                ggml_tensor * x  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
                ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
                ggml_tensor * B  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
                ggml_tensor * C  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
                op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
            } break;
        case GGML_OP_RWKV_WKV6:
            {
                // FIXME
                const int64_t S = 123;
                const int64_t H = 123;
                const int64_t n_tokens = 123;
                const int64_t n_seqs   = 123;
                ggml_tensor * k  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
                ggml_tensor * v  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
                ggml_tensor * r  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
                ggml_tensor * tf = w;
                ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
                ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
                op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
            } break;
        case GGML_OP_IM2COL:
            {
                const int n_embd = hparams.n_embd;
                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
                op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
            } break;
        default:
            GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
    }

    // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
    GGML_ASSERT(w->buffer == nullptr);
    w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
    ggml_backend_buffer_free(w->buffer);
    w->buffer = nullptr;

    return op_supported;
}
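
// Illustrative sketch (not referenced by the loader): the same dummy-buffer
// probe as weight_buft_supported(), specialized to asking the CPU device
// whether it supports a matmul against a hypothetical F16 weight. The
// 4096x4096 weight and 512-row activation shapes are made up for the example.
[[maybe_unused]] static bool example_probe_cpu_mul_mat(void) {
    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev);

    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx { ggml_init(params) };
    if (!ctx) {
        return false;
    }

    ggml_tensor * w = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F16, 4096, 4096);
    ggml_tensor * b = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, 4096,  512);
    ggml_tensor * op_tensor = ggml_mul_mat(ctx.get(), w, b);

    // attach a zero-sized dummy buffer so that supports_op can inspect the buffer type
    w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
    const bool supported = ggml_backend_dev_supports_op(dev, op_tensor);
    ggml_backend_buffer_free(w->buffer);
    w->buffer = nullptr;

    return supported;
}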

// find the first buffer type in the list that can use the tensor
static ggml_backend_buffer_type_t select_weight_buft(const llama_model & model, ggml_tensor * tensor, ggml_op op, const llama_model::buft_list_t & buft_list) {
    GGML_ASSERT(!buft_list.empty());
    for (const auto & cur : buft_list) {
        ggml_backend_dev_t cur_dev = cur.first;
        ggml_backend_buffer_type_t cur_buft = cur.second;
        if (weight_buft_supported(model.hparams, tensor, op, cur_buft, cur_dev)) {
            return cur_buft;
        }
    }
    return nullptr;
}

// CPU: ACCEL -> CPU extra -> GPU host -> CPU
static llama_model::buft_list_t make_cpu_buft_list(llama_model & model) {
    llama_model::buft_list_t buft_list;

    // add ACCEL buffer types
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
            auto * buft = ggml_backend_dev_buffer_type(dev);
            // skip the plain CPU buffer type - it is added at the end of the list as the last fallback
            if (buft != ggml_backend_cpu_buffer_type()) {
                buft_list.emplace_back(dev, buft);
            }
        }
    }

    // add extra buffer types
    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
    auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
    if (ggml_backend_dev_get_extra_bufts_fn) {
        ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
        while (extra_bufts && *extra_bufts) {
            buft_list.emplace_back(cpu_dev, *extra_bufts);
            ++extra_bufts;
        }
    }

    // add a host buffer type
    // storing the tensors in a host buffer is useful when the processing of large batches
    // is offloaded to a GPU device, since it reduces the time spent on data transfers
    // generally, this will be done using the first device in the list
    // a better approach would be to handle this on a weight-by-weight basis using the offload_op
    // function of the device to determine if it would benefit from being stored in a host buffer
    for (auto * dev : model.devices) {
        ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
        if (buft) {
            buft_list.emplace_back(dev, buft);
            break;
        }
    }

    // add the CPU buffer type
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
            buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
        }
    }

    return buft_list;
}
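
// Illustrative sketch (not referenced by the loader): dump the CPU buffer type
// priority list built above. With a GPU backend present this typically lists
// any accelerator/extra CPU buffer types first, then the GPU host buffer type,
// then the plain CPU buffer type; the exact entries depend on the build.
[[maybe_unused]] static void example_print_cpu_buft_list(llama_model & model) {
    for (const auto & entry : make_cpu_buft_list(model)) {
        LLAMA_LOG_INFO("%s: dev = %s, buft = %s\n", __func__,
            ggml_backend_dev_name(entry.first),
            ggml_backend_buft_name(entry.second));
    }
}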

// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU
static llama_model::buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) {
    llama_model::buft_list_t buft_list;

    // add the device split buffer type if requested and available
    if (split_mode == LLAMA_SPLIT_MODE_ROW) {
        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
        auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type");
        if (ggml_backend_split_buffer_type_fn) {
            size_t dev_index = [&]() {
                auto * reg = ggml_backend_dev_backend_reg(dev);
                for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) {
                    if (ggml_backend_reg_dev_get(reg, i) == dev) {
                        return i;
                    }
                }
                throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev)));
            }();
            auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split);
            if (buft != nullptr) {
                buft_list.emplace_back(dev, buft);
            }
        }
    }

    // add the device default buffer type
    buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));

    return buft_list;
}
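
// Illustrative sketch (not referenced by the loader): the optional-feature
// pattern used above. A backend may or may not export
// "ggml_backend_split_buffer_type"; resolving it through the registry's proc
// address table keeps this file free of backend-specific headers, and the
// result is simply nullptr when the backend does not provide the symbol.
[[maybe_unused]] static ggml_backend_split_buffer_type_t example_get_split_buft_fn(ggml_backend_dev_t dev) {
    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
    return (ggml_backend_split_buffer_type_t)
        ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type");
}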

// Returns false if cancelled by progress_callback
static bool llm_load_tensors(
        llama_model_loader & ml,
        llama_model & model,
        int n_gpu_layers,
        enum llama_split_mode split_mode,
        int main_gpu,
        const float * tensor_split,
        bool use_mlock,
        llama_progress_callback progress_callback,
        void * progress_callback_user_data) {
    auto & hparams = model.hparams;

    model.split_mode   = split_mode;
    model.main_gpu     = main_gpu;
    model.n_gpu_layers = n_gpu_layers;

    const int n_layer = hparams.n_layer;

    bool use_mmap_buffer = true;

    // build a list of buffer types for the CPU and GPU devices
    model.cpu_buft_list = make_cpu_buft_list(model);
    for (auto * dev : model.devices) {
        llama_model::buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);
        // add CPU buffer types as a fallback
        buft_list.insert(buft_list.end(), model.cpu_buft_list.begin(), model.cpu_buft_list.end());
        model.gpu_buft_list.emplace(dev, std::move(buft_list));
    }

    // calculate the split points
    int device_count = llama_get_device_count(model);
    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
    std::vector<float> splits(device_count);
    if (all_zero) {
        // default split, by free memory
        for (int i = 0; i < device_count; ++i) {
            ggml_backend_dev_t dev = model.devices[i];
            size_t total;
            size_t free;
            ggml_backend_dev_memory(dev, &free, &total);
            splits[i] = free;
        }
    } else {
        std::copy(tensor_split, tensor_split + device_count, splits.begin());
    }

    // sum and normalize the splits to get the split points
    float split_sum = 0.0f;
    for (int i = 0; i < device_count; ++i) {
        split_sum += splits[i];
        splits[i] = split_sum;
    }
    for (int i = 0; i < device_count; ++i) {
        splits[i] /= split_sum;
    }
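
    // worked example (illustrative, numbers made up): two devices reporting
    // 8 GiB and 24 GiB of free memory start as splits = {8, 24}; the prefix sum
    // turns this into {8, 32} and the normalization into {0.25, 1.0}, so device 0
    // owns the first quarter of the offloaded layers and device 1 the rest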

    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
    const int act_gpu_layers = model.devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
    auto get_layer_buft_list = [&](int il) -> llama_model::layer_dev {
        if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
            return {cpu_dev, &model.cpu_buft_list};
        }
        int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
        auto * dev = model.devices.at(layer_gpu);
        return {dev, &model.gpu_buft_list.at(dev)};
    };
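
    // continuing the illustrative numbers above: with splits = {0.25, 1.0} and a
    // fully offloaded 32-layer model (i_gpu_start = 0, act_gpu_layers = 33),
    // layer 5 has fraction 5/33 ~= 0.15 and maps to device 0, while layer 16 has
    // fraction 16/33 ~= 0.48 and maps to device 1, because std::upper_bound
    // returns the first split point strictly greater than the layer's fraction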

    // assign the input layer
    // there is very little benefit to offloading the input layer, so always keep it on the CPU
    model.dev_input = { cpu_dev, &model.cpu_buft_list };

    // assign the repeating layers to the devices according to the splits
    model.dev_layer.resize(n_layer);
    for (int il = 0; il < n_layer; ++il) {
        model.dev_layer[il] = get_layer_buft_list(il);
    }

    // assign the output layer
    model.dev_output = get_layer_buft_list(n_layer);

    // one ggml context per buffer type
    int max_n_tensors = ml.n_tensors;
    max_n_tensors += 1;         // duplicated output tensor
    max_n_tensors += n_layer*2; // duplicated rope freq tensors
    const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;

    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ggml_init_params params = {
                /*.mem_size   =*/ ctx_size,
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };
            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                throw std::runtime_error(format("failed to create ggml context"));
            }
            ctx_map[buft] = ctx;
            model.ctxs.emplace_back(ctx);
            return ctx;
        }
        return it->second;
    };
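
    // illustrative note: ctx_map usually ends up with only a handful of entries -
    // e.g. one context for the GPU buffer type that holds the repeating layers
    // and one for the CPU (or GPU host) buffer type that holds the token
    // embeddings; the exact set depends on the devices present and on what
    // select_weight_buft picks for each weight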

    // create tensors for the weights
    {
        // note: cast to int64_t since we will use these for the tensor dimensions
        const int64_t n_head        = hparams.n_head();
        const int64_t n_head_kv     = hparams.n_head_kv();
        const int64_t n_embd        = hparams.n_embd;
        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
        const int64_t n_embd_head_k = hparams.n_embd_head_k;
        const int64_t n_embd_head_v = hparams.n_embd_head_v;
        const int64_t n_ff          = hparams.n_ff();
        const int64_t n_embd_gqa    = n_embd_v_gqa;
        const int64_t n_vocab       = hparams.n_vocab;
        const int64_t n_vocab_type  = hparams.n_vocab_type;
        const int64_t n_rot         = hparams.n_rot;
        const int64_t n_expert      = hparams.n_expert;
        const int64_t n_expert_used = hparams.n_expert_used;
        const int64_t n_ctx_train   = hparams.n_ctx_train;

        if (n_expert > 0 && hparams.n_expert_used == 0) {
            throw std::runtime_error("model has expert layers but no expert layers are used");
        }

        int n_moved_tensors = 0;
        ggml_tensor * first_moved_tensor = nullptr;
        ggml_backend_buffer_type_t first_moved_from_buft = nullptr;
        ggml_backend_buffer_type_t first_moved_to_buft   = nullptr;

        auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * {
            ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());

            if (!t_meta) {
                if (flags & llama_model_loader::TENSOR_NOT_REQUIRED) {
                    return nullptr;
                }
                throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
            }

            // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
            // the tensor is duplicated
            // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor
            llm_tensor tn_tensor = tn.tensor;
            if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & llama_model_loader::TENSOR_DUPLICATED) {
                tn_tensor = LLM_TENSOR_OUTPUT;
            }

            llm_tensor_info info;
            try {
                info = llm_tensor_info_for(tn_tensor);
            } catch (const std::out_of_range & e) {
                throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
            }

            // tensors with "bias" suffix are always used with GGML_OP_ADD
            ggml_op op;
            bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
            if (bias) {
                op = GGML_OP_ADD;
            } else {
                op = info.op;
            }

            // sanity checks
            if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
                if (tn.bid != -1) {
                    GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());
                }
            } else {
                if (tn.bid == -1) {
                    GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str());
                }
            }

            // select the buffer type for this tensor
            llama_model::buft_list_t * buft_list;
            switch (info.layer) {
                case LLM_TENSOR_LAYER_INPUT:
                    buft_list = model.dev_input.buft_list;
                    break;
                case LLM_TENSOR_LAYER_OUTPUT:
                    buft_list = model.dev_output.buft_list;
                    break;
                case LLM_TENSOR_LAYER_REPEATING:
                    buft_list = model.dev_layer.at(tn.bid).buft_list;
                    break;
                default:
                    GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
            }

            ggml_backend_buffer_type_t buft = select_weight_buft(model, t_meta, op, *buft_list);
            if (!buft) {
                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
            }

            // avoid using a host buffer when using mmap
            auto * buft_dev = ggml_backend_buft_get_device(buft);
            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
                buft = ggml_backend_dev_buffer_type(cpu_dev);
            }

            if (buft != buft_list->front().second) {
                n_moved_tensors++;
                if (!first_moved_tensor) {
                    first_moved_tensor    = t_meta;
                    first_moved_from_buft = buft_list->front().second;
                    first_moved_to_buft   = buft;
                }
            }

            ggml_context * ctx = ctx_for_buft(buft);

            // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one
            if (flags & llama_model_loader::TENSOR_DUPLICATED) {
                ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str());
                if (t) {
                    return t;
                }
            }
            return ml.create_tensor(ctx, tn, ne, flags);
        };
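
        // usage sketch (taken from the per-architecture cases below): a call like
        //   create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED)
        // returns nullptr when the checkpoint has no separate output matrix; the
        // caller then retries with the token embedding name and TENSOR_DUPLICATED,
        // so the same weight is reused as the output projection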

        model.layers.resize(n_layer);

        // TODO: move to a separate function
        const auto tn = LLM_TN(model.arch);
switch (model.arch) {
|
|
case LLM_ARCH_LLAMA:
|
|
case LLM_ARCH_REFACT:
|
|
case LLM_ARCH_MINICPM:
|
|
case LLM_ARCH_GRANITE:
|
|
case LLM_ARCH_GRANITE_MOE:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
|
|
|
// optional bias tensors
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
|
|
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
}
|
|
else {
|
|
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
}
|
|
|
|
if (n_expert == 0) {
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
|
|
// optional MLP bias
|
|
layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
} else {
|
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
|
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
|
|
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
|
}
|
|
}
|
|
} break;
|
|
case LLM_ARCH_DECI:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
|
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i);
|
|
const int64_t n_ff = hparams.n_ff(i);
|
|
const int64_t n_head = hparams.n_head(i);
|
|
const int64_t n_head_kv = hparams.n_head_kv(i);
|
|
|
|
if (n_head_kv == 0 && n_head > 0) {
|
|
// linear attention for DeciLMCausalModel
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
}
|
|
else if (n_head_kv > 0) {
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
|
}
|
|
|
|
// optional bias tensors
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
|
|
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
}
|
|
else {
|
|
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
}
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
|
|
// optional MLP bias
|
|
layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_MINICPM3:
|
|
{
|
|
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
|
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
|
|
|
const int64_t q_lora_rank = hparams.n_lora_q;
|
|
const int64_t kv_lora_rank = hparams.n_lora_kv;
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
|
|
|
|
layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
|
|
|
|
layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
|
|
layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
|
|
|
|
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
|
|
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
|
|
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
}
|
|
} break;
|
|
case LLM_ARCH_GROK:
|
|
{
|
|
if (n_expert == 0) {
|
|
throw std::runtime_error("Grok model cannot have zero experts");
|
|
}
|
|
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
|
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
|
|
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
|
|
|
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_DBRX:
|
|
{
|
|
if (n_expert == 0) {
|
|
throw std::runtime_error("DBRX model cannot have zero experts");
|
|
}
|
|
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
|
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
|
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
|
|
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_BAICHUAN:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
{
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_FALCON:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
{
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
if (!model.output) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
|
|
}
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_STARCODER:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
|
|
|
|
// output
|
|
{
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
if (!model.output) {
|
|
// needs to be on GPU
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_BERT:
|
|
case LLM_ARCH_NOMIC_BERT:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
model.type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0);
|
|
|
|
if (model.arch == LLM_ARCH_BERT) {
|
|
model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
|
|
|
|
model.cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
model.cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
model.cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
}
|
|
|
|
model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
|
model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
if (model.arch == LLM_ARCH_BERT) {
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
|
|
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
|
} else {
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
}
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
|
|
if (model.arch == LLM_ARCH_BERT) {
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
} else {
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
|
|
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
|
|
layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_JINA_BERT_V2:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
|
|
model.type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0); // token_type_embeddings
|
|
|
|
model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
|
|
model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias
|
|
|
|
model.cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i]; // JinaBertLayer
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
|
|
|
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
|
|
|
|
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens
|
|
|
|
layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm
|
|
layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
|
|
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
|
|
layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_BLOOM:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
|
model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_MPT:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
if (!model.output) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
// AWQ ScaleActivation layer
|
|
layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_STABLELM:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
// optional bias tensors, present in Stable LM 2 1.6B
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
// optional q and k layernorms, present in StableLM 2 12B
|
|
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
// optional FFN norm, not present in StableLM 2 12B which uses parallel residual
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_QWEN:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_QWEN2:
|
|
case LLM_ARCH_QWEN2VL:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
// optional bias tensors
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_QWEN2MOE:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
// optional bias tensors
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
|
|
|
if (n_expert == 0) {
|
|
throw std::runtime_error("n_expert must be > 0 for QWEN2MOE");
|
|
}
|
|
if (n_expert_used == 0) {
|
|
throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE");
|
|
}
|
|
|
|
// MoE branch
|
|
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
|
|
|
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
|
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
|
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
|
|
|
// Shared expert branch
|
|
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
|
|
|
|
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
|
|
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
|
|
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_PHI2:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
if (layer.wqkv == nullptr) {
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
|
|
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
|
}
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_PHI3:
|
|
{
|
|
const int64_t n_embd_head = n_embd / n_head;
|
|
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
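// Phi-3 style checkpoints fuse the gate and up projections into a single tensor,
// hence the 2*n_ff width of ffn_up above and the absence of a separate ffn_gate;
// the compute graph is expected to split this tensor into its two halves.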
|
|
|
|
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
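// the long/short rope factor tensors above hold one scaling factor per rotary
// dimension pair (n_embd_head/2 entries); they are optional, and for i > 0 they are
// flagged TENSOR_DUPLICATED because every layer reuses the same factors from the file.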
|
|
}
|
|
} break;
|
|
case LLM_ARCH_PLAMO:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_GPT2:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_CODESHELL:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_ORION:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_INTERNLM2:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
// layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_GEMMA:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_GEMMA2:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
|
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_STARCODER2:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
// bias tensors
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
|
|
// bias tensors
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_MAMBA:
|
|
{
|
|
const int64_t d_conv = hparams.ssm_d_conv;
|
|
const int64_t d_inner = hparams.ssm_d_inner;
|
|
const int64_t d_state = hparams.ssm_d_state;
|
|
const int64_t dt_rank = hparams.ssm_dt_rank;
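// quick sanity example for the check below (hypothetical sizes, not a specific model):
// with n_embd = 768 a loadable Mamba model must have d_inner = 1536, since only the
// 2x expansion is supported; d_conv is the depthwise conv kernel width, d_state the
// per-channel SSM state size and dt_rank the rank of the delta-t projection.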
|
|
|
|
// only an expansion factor of 2 is supported for now
|
|
if (2 * n_embd != d_inner) {
|
|
throw std::runtime_error("only an expansion factor of 2 is supported for now");
|
|
}
|
|
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
// if output is NULL, init from the input tok embed, duplicated to allow offloading
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
// norm
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0);
|
|
|
|
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0);
|
|
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0);
|
|
|
|
layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0);
|
|
|
|
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0);
|
|
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0);
|
|
|
|
// no "weight" suffix for these
|
|
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
|
|
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
|
|
|
|
// out_proj
|
|
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_XVERSE:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_COMMAND_R:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
// init output from the input tok embed
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
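// in practice only the larger Command R variants (e.g. the 104B model) ship per-head
// Q/K normalization tensors; the layer-count heuristic below is used to tell those
// checkpoints apart.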
|
|
|
|
if (n_layer >= 64){
|
|
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
|
|
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
|
|
}
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_COHERE2:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
|
// init output from the input tok embed
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
|
|
llama_model_loader::TENSOR_DUPLICATED);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_OLMO2:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_OLMOE:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
|
|
|
if (n_expert == 0) {
|
|
throw std::runtime_error("n_expert must be > 0");
|
|
}
|
|
if (n_expert_used == 0) {
|
|
throw std::runtime_error("n_expert_used must be > 0");
|
|
}
|
|
|
|
// MoE branch
|
|
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
|
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
|
|
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_OPENELM:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
// init output from the input tok embed
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
const int64_t n_head = hparams.n_head(i);
|
|
const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head;
|
|
const int64_t n_ff = hparams.n_ff(i);
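// OpenELM scales the attention and FFN widths per layer, so the head counts and n_ff
// are re-read from the hparams for every i; the fused QKV weight below therefore packs
// n_head query heads plus 2*n_head_kv key/value heads, each of size n_embd_head_k.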
|
|
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0);
|
|
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
|
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_GPTNEOX:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_ARCTIC:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
|
layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
|
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
|
|
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_DEEPSEEK:
|
|
{
|
|
|
|
const int64_t n_ff_exp = hparams.n_ff_exp;
|
|
const int64_t n_expert_shared = hparams.n_expert_shared;
|
|
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
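// the first hparams.n_layer_dense_lead layers keep a plain dense FFN; all later layers
// use the routed-expert (MoE) branch plus a shared-expert branch instead.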
|
|
|
|
if (i < (int) hparams.n_layer_dense_lead) {
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
} else {
|
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
|
|
|
if (n_expert == 0) {
|
|
throw std::runtime_error("n_expert must be > 0");
|
|
}
|
|
if (n_expert_used == 0) {
|
|
throw std::runtime_error("n_expert_used must be > 0");
|
|
}
|
|
|
|
// MoE branch
|
|
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
|
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
|
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
|
|
|
// Shared expert branch
|
|
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
|
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
|
|
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
|
}
|
|
}
|
|
} break;
|
|
case LLM_ARCH_DEEPSEEK2:
|
|
{
|
|
const bool is_lite = (hparams.n_layer == 27);
|
|
|
|
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
|
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
|
|
|
const int64_t q_lora_rank = hparams.n_lora_q;
|
|
const int64_t kv_lora_rank = hparams.n_lora_kv;
|
|
|
|
const int64_t n_ff_exp = hparams.n_ff_exp;
|
|
const int64_t n_expert_shared = hparams.n_expert_shared;
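// DEEPSEEK2 uses multi-head latent attention: queries and keys/values are factored
// through low-rank projections (q_lora_rank, kv_lora_rank) and each head is split
// into a rotary part (n_embd_head_qk_rope) and a non-rotary part (n_embd_head_qk_nope);
// the "lite" variant is detected purely by its 27-layer depth and skips the Q-side
// low-rank factorization.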
|
|
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
if (!is_lite) {
|
|
layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
|
|
}
|
|
|
|
layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
|
|
|
|
if (!is_lite) {
|
|
layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
|
|
layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
|
|
} else {
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
}
|
|
|
|
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
|
|
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
if (i < (int) hparams.n_layer_dense_lead) {
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
} else {
|
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
|
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
if (n_expert == 0) {
|
|
throw std::runtime_error("n_expert must be > 0");
|
|
}
|
|
if (n_expert_used == 0) {
|
|
throw std::runtime_error("n_expert_used must be > 0");
|
|
}
|
|
|
|
// MoE branch
|
|
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
|
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
|
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
|
|
|
// Shared expert branch
|
|
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
|
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
|
|
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
|
}
|
|
}
|
|
} break;
|
|
case LLM_ARCH_BITNET:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
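// the projections below each come with an optional single-element "scale" tensor
// ({1}); these are flagged not-required so that checkpoints exported without separate
// scale tensors still load.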
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up_scale = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_T5:
|
|
{
|
|
const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
|
|
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
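// in T5 the relative position bias is typically stored only for the first block and
// shared by the others, which is why the per-layer rel_b tensors are not required.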
|
|
|
|
layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
|
|
|
|
layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
|
|
|
|
layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
// this tensor seems to be unused in HF transformers implementation
|
|
layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_T5ENCODER:
|
|
{
|
|
const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
|
|
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
|
|
|
|
layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_JAIS:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, 0);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_CHATGLM:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
|
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
|
|
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0);
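// ChatGLM fuses the gate and up FFN projections into one tensor, hence the 2*n_ff
// width of ffn_up above and the missing ffn_gate; the graph is expected to split it.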
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_NEMOTRON:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
// optional bias tensors
|
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
|
|
// optional MLP bias
|
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_EXAONE:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_RWKV6:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// Block 0, LN0
|
|
model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
|
model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
const int time_mix_extra_dim = hparams.time_mix_extra_dim;
|
|
const int time_decay_extra_dim = hparams.time_decay_extra_dim;
|
|
const int head_size = hparams.wkv_head_size;
|
|
const int attn_hidden_size = n_embd;
|
|
const int ffn_size = hparams.n_ff_arr[0];
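// time_mix_w1/w2 below form a low-rank bottleneck of width time_mix_extra_dim that
// produces the five per-channel interpolation signals (w, k, v, r, g), matching the
// {n_embd, time_mix_extra_dim * 5} and {time_mix_extra_dim, n_embd, 5} shapes;
// wkv_head_size splits n_embd into n_embd / head_size heads for the WKV state.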
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
|
|
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0);
|
|
|
|
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
|
|
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
|
|
|
|
layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
|
|
layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
|
|
layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
|
|
layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
|
|
layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
|
|
layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
|
|
|
|
layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
|
|
layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
|
|
layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
|
|
layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
|
|
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
|
|
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
|
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
|
layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
|
|
|
layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
|
|
layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
|
|
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
|
|
|
|
layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
|
|
layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
|
|
|
|
layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
|
|
layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
|
|
layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0);
|
|
}
|
|
|
|
} break;
|
|
case LLM_ARCH_CHAMELEON:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
// if output is NULL, init from the input tok embed
|
|
if (model.output == NULL) {
|
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
|
}
|
|
|
|
for (int i = 0; i < n_layer; ++i) {
|
|
auto & layer = model.layers[i];
|
|
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
|
|
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
|
|
layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
|
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
}
|
|
} break;
|
|
case LLM_ARCH_WAVTOKENIZER_DEC:
|
|
{
|
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
|
|
|
|
model.conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0);
|
|
model.conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0);
|
|
|
|
// posnet
|
|
{
|
|
const int64_t n_embd = hparams.posnet.n_embd;
|
|
|
|
for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) {
|
|
auto & layer = model.layers[i].posnet;
|
|
|
|
// posnet:
|
|
//
|
|
// - resnet
|
|
// - resnet
|
|
// - attn
|
|
// - resnet
|
|
// - resnet
|
|
// - norm
|
|
//
|
|
switch (i) {
|
|
case 0:
|
|
case 1:
|
|
case 3:
|
|
case 4:
|
|
{
|
|
layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0);
|
|
layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0);
|
|
|
|
layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0);
|
|
layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0);
|
|
|
|
layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0);
|
|
layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0);
|
|
|
|
layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0);
|
|
layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0);
|
|
} break;
|
|
case 2:
|
|
{
|
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
|
|
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0);
|
|
|
|
layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0);
|
|
layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0);
|
|
|
|
layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0);
|
|
layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0);
|
|
|
|
layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0);
|
|
layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0);
|
|
|
|
layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0);
|
|
layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0);
|
|
} break;
|
|
case 5:
|
|
{
|
|
layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
|
|
layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0);
|
|
} break;
|
|
default: GGML_ABORT("unknown posnet layer");
|
|
};
|
|
}
|
|
}
|
|
|
|
GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
|
|
|
|
model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
|
|
model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0);
|
|
|
|
// convnext
|
|
{
|
|
const int64_t n_embd = hparams.convnext.n_embd;
|
|
|
|
for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) {
|
|
auto & layer = model.layers[i].convnext;
|
|
|
|
layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0);
|
|
layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0);
|
|
|
|
layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0);
|
|
layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0);
|
|
|
|
layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0);
|
|
layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0);
|
|
|
|
layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0);
|
|
layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0);
|
|
|
|
layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0);
|
|
}
|
|
|
|
// output
|
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
|
}
|
|
|
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
|
|
model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
|
|
} break;
|
|
default:
|
|
throw std::runtime_error("unknown architecture");
|
|
}
|
|
|
|
if (n_moved_tensors > 0) {
|
|
LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n",
|
|
__func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1,
|
|
ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft));
|
|
}
|
|
}
|
|
|
|
ml.done_getting_tensors();
|
|
|
|
ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr);
|
|
model.mappings.reserve(ml.mappings.size());
|
|
|
|
// create the backend buffers
|
|
std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
|
|
ctx_bufs.reserve(ctx_map.size());
|
|
|
|
// Ensure we have enough capacity for the maximum number of backend buffers we will potentially create
|
|
const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
|
|
model.bufs.reserve(n_max_backend_buffer);
|
|
|
|
for (auto & it : ctx_map) {
|
|
ggml_backend_buffer_type_t buft = it.first;
|
|
ggml_context * ctx = it.second;
|
|
|
|
// skip contexts without tensors
|
|
if (ggml_get_first_tensor(ctx) == nullptr) {
|
|
continue;
|
|
}
|
|
|
|
llama_buf_map bufs;
|
|
bufs.reserve(n_max_backend_buffer);
|
|
|
|
// check if it is possible to use buffer_from_host_ptr with this buffer type
|
|
ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
|
|
if (!dev) {
|
|
// FIXME: workaround for CPU backend buft having a NULL device
|
|
dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
|
}
|
|
ggml_backend_dev_props props;
|
|
ggml_backend_dev_get_props(dev, &props);
|
|
bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
|
|
bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
|
|
|
|
if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
|
|
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
|
|
// only the mmap region containing the tensors in the model is mapped to the backend buffer
|
|
// this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
|
|
// this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
|
|
void * addr = nullptr;
|
|
size_t first, last; // NOLINT
|
|
ml.get_mapping_range(&first, &last, &addr, idx, ctx);
|
|
if (first >= last) {
|
|
continue;
|
|
}
|
|
const size_t max_size = ggml_get_max_tensor_size(ctx);
|
|
ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
|
|
if (buf == nullptr) {
|
|
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
|
|
}
|
|
model.bufs.emplace_back(buf);
|
|
bufs.emplace(idx, buf);
|
|
}
|
|
}
|
|
else {
|
|
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
|
if (buf == nullptr) {
|
|
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
|
|
}
|
|
model.bufs.emplace_back(buf);
|
|
if (use_mlock && ggml_backend_buffer_is_host(buf)) {
|
|
model.mlock_bufs.emplace_back(new llama_mlock);
|
|
auto & mlock_buf = model.mlock_bufs.back();
|
|
mlock_buf->init (ggml_backend_buffer_get_base(buf));
|
|
mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
|
|
}
|
|
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
|
|
bufs.emplace(idx, buf);
|
|
}
|
|
}
|
|
|
|
if (bufs.empty()) {
|
|
throw std::runtime_error("failed to allocate buffer");
|
|
}
|
|
|
|
for (auto & buf : bufs) {
|
|
// indicate that this buffer contains weights
|
|
// this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight
|
|
ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
|
}
|
|
|
|
ctx_bufs.emplace_back(ctx, bufs);
|
|
}
|
|
|
|
if (llama_supports_gpu_offload()) {
|
|
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
|
|
|
|
LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
|
|
if (n_gpu_layers > (int) hparams.n_layer) {
|
|
LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
|
|
}
|
|
|
|
const int max_backend_supported_layers = hparams.n_layer + 1;
|
|
const int max_offloadable_layers = hparams.n_layer + 1;
|
|
|
|
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
|
|
}
|
|
|
|
// print memory requirements per buffer type
|
|
for (auto & buf : model.bufs) {
|
|
LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
|
|
}
|
|
|
|
// populate tensors_by_name
|
|
for (auto & ctx : model.ctxs) {
|
|
for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
|
|
model.tensors_by_name.emplace_back(ggml_get_name(cur), cur);
|
|
}
|
|
}
|
|
|
|
// load tensor data
|
|
for (auto & it : ctx_bufs) {
|
|
ggml_context * ctx = it.first;
|
|
auto & bufs = it.second;
|
|
if (!ml.load_all_data(ctx, bufs, use_mlock ? &model.mlock_mmaps : NULL, progress_callback, progress_callback_user_data)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
if (use_mmap_buffer) {
|
|
for (auto & mapping : ml.mappings) {
|
|
model.mappings.emplace_back(std::move(mapping));
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
|
static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
|
|
model.t_start_us = ggml_time_us();
|
|
|
|
try {
|
|
llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
|
|
|
|
model.hparams.vocab_only = params.vocab_only;
|
|
|
|
try {
|
|
llm_load_arch(ml, model);
|
|
} catch(const std::exception & e) {
|
|
throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
|
|
}
|
|
try {
|
|
llm_load_hparams(ml, model);
|
|
} catch(const std::exception & e) {
|
|
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
|
|
}
|
|
try {
|
|
llm_load_vocab(ml, model);
|
|
} catch(const std::exception & e) {
|
|
throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
|
|
}
|
|
|
|
llm_load_stats(ml, model);
|
|
llm_load_print_meta(ml, model);
|
|
|
|
if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
|
|
model.hparams.n_vocab != model.vocab.id_to_token.size()) {
|
|
throw std::runtime_error("vocab size mismatch");
|
|
}
|
|
|
|
if (params.vocab_only) {
|
|
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
|
|
return 0;
|
|
}
|
|
|
|
if (!llm_load_tensors(
|
|
ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
|
|
params.progress_callback, params.progress_callback_user_data
|
|
)) {
|
|
return -2;
|
|
}
|
|
} catch (const std::exception & err) {
|
|
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
|
return -1;
|
|
}
|
|
|
|
// loading time will be recalculated after the first eval, so
|
|
// we take page faults deferred by mmap() into consideration
|
|
model.t_load_us = ggml_time_us() - model.t_start_us;
|
|
|
|
return 0;
|
|
}
|
|
|
|
//
|
|
// llm_build
|
|
//
|
|
|
|
using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
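// the callback receives the newly created tensor, a short debug name, and the layer index
// (nl == -1 is used for graph-level tensors that do not belong to a specific layer)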
|
|
|
|
enum llm_ffn_op_type {
|
|
LLM_FFN_SILU,
|
|
LLM_FFN_GELU,
|
|
LLM_FFN_RELU,
|
|
LLM_FFN_RELU_SQR,
|
|
LLM_FFN_SWIGLU,
|
|
};
|
|
|
|
enum llm_ffn_gate_type {
|
|
LLM_FFN_SEQ,
|
|
LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
|
|
};
|
|
|
|
enum llm_norm_type {
|
|
LLM_NORM,
|
|
LLM_NORM_RMS,
|
|
LLM_NORM_GROUP,
|
|
};
|
|
|
|
static struct ggml_tensor * llm_build_inp_embd(
|
|
struct ggml_context * ctx,
|
|
struct llama_context & lctx,
|
|
const llama_hparams & hparams,
|
|
const llama_ubatch & ubatch,
|
|
struct ggml_tensor * tok_embd,
|
|
const llm_build_cb & cb) {
|
|
const int64_t n_embd = hparams.n_embd;
|
|
|
|
struct ggml_tensor * inpL;
|
|
|
|
if (ubatch.token) {
|
|
lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
|
|
cb(lctx.inp_tokens, "inp_tokens", -1);
|
|
ggml_set_input(lctx.inp_tokens);
|
|
|
|
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
|
|
} else {
|
|
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
|
|
inpL = lctx.inp_embd;
|
|
ggml_set_input(lctx.inp_embd);
|
|
}
|
|
|
|
// For Granite architecture
|
|
if (hparams.f_embedding_scale != 0.0f) {
|
|
inpL = ggml_scale(ctx, inpL, hparams.f_embedding_scale);
|
|
}
|
|
|
|
cb(inpL, "inp_embd", -1);
|
|
|
|
return inpL;
|
|
}
|
|
|
|
static void llm_build_kv_store(
|
|
struct ggml_context * ctx,
|
|
const llama_hparams & hparams,
|
|
const llama_cparams & cparams,
|
|
const llama_kv_cache & kv,
|
|
struct ggml_cgraph * graph,
|
|
struct ggml_tensor * k_cur,
|
|
struct ggml_tensor * v_cur,
|
|
int32_t n_tokens,
|
|
int32_t kv_head,
|
|
const llm_build_cb & cb,
|
|
int64_t il) {
|
|
const int64_t n_ctx = cparams.n_ctx;
|
|
|
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
|
|
GGML_ASSERT(kv.size == n_ctx);
|
|
|
|
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head);
|
|
cb(k_cache_view, "k_cache_view", il);
|
|
|
|
// note: storing RoPE-ed version of K in the KV cache
|
|
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
|
|
|
|
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
|
|
|
|
struct ggml_tensor * v_cache_view = nullptr;
|
|
|
|
if (cparams.flash_attn) {
|
|
v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head);
|
|
} else {
|
|
// note: the V cache is transposed when not using flash attention
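// storing V transposed (one cache row per channel across the whole context) lets
// llm_build_kqv below view it as {n_kv, n_embd_head_v, n_head_kv} for the KQ*V
// matmul without an extra permute at attention time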
|
|
v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
|
|
( n_ctx)*ggml_element_size(kv.v_l[il]),
|
|
(kv_head)*ggml_element_size(kv.v_l[il]));
|
|
|
|
v_cur = ggml_transpose(ctx, v_cur);
|
|
}
|
|
cb(v_cache_view, "v_cache_view", il);
|
|
|
|
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
|
|
}
|
|
|
|
// do mat_mul, optionally applying LoRA
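// i.e. res = W*cur + sum over adapters of s * (alpha/rank) * B*(A*cur),
// where s is the per-adapter user scale (it.second), rank = b->ne[0], and the
// alpha/rank factor is skipped when alpha == 0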
|
|
static struct ggml_tensor * llm_build_lora_mm(
|
|
struct llama_context & lctx,
|
|
struct ggml_context * ctx0,
|
|
struct ggml_tensor * w,
|
|
struct ggml_tensor * cur) {
|
|
struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
|
|
for (auto & it : lctx.lora_adapters) {
|
|
struct llama_lora_weight * lora = it.first->get_weight(w);
|
|
if (lora == nullptr) {
|
|
continue;
|
|
}
|
|
const float alpha = it.first->alpha;
|
|
const float rank = (float) lora->b->ne[0];
|
|
const float scale = alpha ? it.second * alpha / rank : it.second;
|
|
struct ggml_tensor * ab_cur = ggml_mul_mat(
|
|
ctx0, lora->b,
|
|
ggml_mul_mat(ctx0, lora->a, cur)
|
|
);
|
|
ab_cur = ggml_scale(ctx0, ab_cur, scale);
|
|
res = ggml_add(ctx0, res, ab_cur);
|
|
}
|
|
return res;
|
|
}
|
|
|
|
// do mat_mul_id, optionally applying LoRA
|
|
static struct ggml_tensor * llm_build_lora_mm_id(
|
|
struct llama_context & lctx,
|
|
struct ggml_context * ctx0,
|
|
struct ggml_tensor * w, // struct ggml_tensor * as
|
|
struct ggml_tensor * cur, // struct ggml_tensor * b
|
|
struct ggml_tensor * ids) {
|
|
struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
|
|
for (auto & it : lctx.lora_adapters) {
|
|
struct llama_lora_weight * lora = it.first->get_weight(w);
|
|
if (lora == nullptr) {
|
|
continue;
|
|
}
|
|
const float alpha = it.first->alpha;
|
|
const float rank = (float) lora->b->ne[0];
|
|
const float scale = alpha ? it.second * alpha / rank : it.second;
|
|
struct ggml_tensor * ab_cur = ggml_mul_mat_id(
|
|
ctx0, lora->b,
|
|
ggml_mul_mat_id(ctx0, lora->a, cur, ids),
|
|
ids
|
|
);
|
|
ab_cur = ggml_scale(ctx0, ab_cur, scale);
|
|
res = ggml_add(ctx0, res, ab_cur);
|
|
}
|
|
return res;
|
|
}
|
|
|
|
static struct ggml_tensor * llm_build_norm(
|
|
struct ggml_context * ctx,
|
|
struct ggml_tensor * cur,
|
|
const llama_hparams & hparams,
|
|
struct ggml_tensor * mw,
|
|
struct ggml_tensor * mb,
|
|
llm_norm_type type,
|
|
const llm_build_cb & cb,
|
|
int il) {
|
|
switch (type) {
|
|
case LLM_NORM: cur = ggml_norm (ctx, cur, hparams.f_norm_eps); break;
|
|
case LLM_NORM_RMS: cur = ggml_rms_norm (ctx, cur, hparams.f_norm_rms_eps); break;
|
|
case LLM_NORM_GROUP:
|
|
{
|
|
cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);
|
|
cur = ggml_group_norm(ctx, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
|
|
cur = ggml_reshape_2d(ctx, cur, cur->ne[0], cur->ne[2]);
|
|
} break;
|
|
}
|
|
|
|
if (mw || mb) {
|
|
cb(cur, "norm", il);
|
|
}
|
|
|
|
if (mw) {
|
|
cur = ggml_mul(ctx, cur, mw);
|
|
if (mb) {
|
|
cb(cur, "norm_w", il);
|
|
}
|
|
}
|
|
|
|
if (mb) {
|
|
cur = ggml_add(ctx, cur, mb);
|
|
}
|
|
|
|
return cur;
|
|
}
|
|
|
|
static struct ggml_tensor * llm_build_ffn(
|
|
struct ggml_context * ctx,
|
|
struct llama_context & lctx,
|
|
struct ggml_tensor * cur,
|
|
struct ggml_tensor * up,
|
|
struct ggml_tensor * up_b,
|
|
struct ggml_tensor * up_s,
|
|
struct ggml_tensor * gate,
|
|
struct ggml_tensor * gate_b,
|
|
struct ggml_tensor * gate_s,
|
|
struct ggml_tensor * down,
|
|
struct ggml_tensor * down_b,
|
|
struct ggml_tensor * down_s,
|
|
struct ggml_tensor * act_scales,
|
|
llm_ffn_op_type type_op,
|
|
llm_ffn_gate_type type_gate,
|
|
const llm_build_cb & cb,
|
|
int il) {
|
|
struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
|
|
cb(tmp, "ffn_up", il);
|
|
|
|
if (up_b) {
|
|
tmp = ggml_add(ctx, tmp, up_b);
|
|
cb(tmp, "ffn_up_b", il);
|
|
}
|
|
|
|
if (up_s) {
|
|
tmp = ggml_mul(ctx, tmp, up_s);
|
|
cb(tmp, "ffn_up_s", il);
|
|
}
|
|
|
|
if (gate) {
|
|
switch (type_gate) {
|
|
case LLM_FFN_SEQ:
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx, gate, tmp);
|
|
cb(cur, "ffn_gate", il);
|
|
} break;
|
|
case LLM_FFN_PAR:
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx, gate, cur);
|
|
cb(cur, "ffn_gate", il);
|
|
} break;
|
|
}
|
|
|
|
if (gate_b) {
|
|
cur = ggml_add(ctx, cur, gate_b);
|
|
cb(cur, "ffn_gate_b", il);
|
|
}
|
|
|
|
if (gate_s) {
|
|
cur = ggml_mul(ctx, cur, gate_s);
|
|
cb(cur, "ffn_gate_s", il);
|
|
}
|
|
|
|
} else {
|
|
cur = tmp;
|
|
}
|
|
|
|
switch (type_op) {
|
|
case LLM_FFN_SILU:
|
|
{
|
|
cur = ggml_silu(ctx, cur);
|
|
cb(cur, "ffn_silu", il);
|
|
} break;
|
|
case LLM_FFN_GELU:
|
|
{
|
|
cur = ggml_gelu(ctx, cur);
|
|
cb(cur, "ffn_gelu", il);
|
|
if (act_scales != NULL) {
|
|
cur = ggml_div(ctx, cur, act_scales);
|
|
cb(cur, "ffn_act", il);
|
|
}
|
|
} break;
|
|
case LLM_FFN_RELU:
|
|
{
|
|
cur = ggml_relu(ctx, cur);
|
|
cb(cur, "ffn_relu", il);
|
|
} break;
|
|
case LLM_FFN_RELU_SQR:
|
|
{
|
|
cur = ggml_relu(ctx, cur);
|
|
cb(cur, "ffn_relu", il);
|
|
|
|
cur = ggml_sqr(ctx, cur);
|
|
cb(cur, "ffn_sqr(relu)", il);
|
|
} break;
|
|
case LLM_FFN_SWIGLU:
|
|
{
|
|
// Project to 4h. If using SwiGLU, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
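// e.g. if ffn_up produced a 2*n_ff wide tensor, split_point == n_ff: the first half
// is passed through SiLU and multiplied element-wise with the second half, giving an
// n_ff wide result (silu(x0) * x1)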
|
|
int64_t split_point = cur->ne[0] / 2;
|
|
struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
|
struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
|
|
|
x0 = ggml_silu(ctx, x0);
|
|
cb(cur, "ffn_silu", il);
|
|
|
|
cur = ggml_mul(ctx, x0, x1);
|
|
cb(cur, "ffn_mul", il);
|
|
} break;
|
|
}
|
|
|
|
if (type_gate == LLM_FFN_PAR) {
|
|
cur = ggml_mul(ctx, cur, tmp);
|
|
cb(cur, "ffn_gate_par", il);
|
|
}
|
|
|
|
if (down) {
|
|
cur = llm_build_lora_mm(lctx, ctx, down, cur);
|
|
}
|
|
|
|
if (down_b) {
|
|
cb(cur, "ffn_down", il);
|
|
}
|
|
|
|
if (down_b) {
|
|
cur = ggml_add(ctx, cur, down_b);
|
|
}
|
|
|
|
if (down_s) {
|
|
cur = ggml_mul(ctx, cur, down_s);
|
|
cb(cur, "ffn_down_s", il);
|
|
}
|
|
|
|
return cur;
|
|
}
|
|
|
|
static struct ggml_tensor * llm_build_moe_ffn(
|
|
struct ggml_context * ctx,
|
|
struct llama_context & lctx,
|
|
struct ggml_tensor * cur,
|
|
struct ggml_tensor * gate_inp,
|
|
struct ggml_tensor * up_exps,
|
|
struct ggml_tensor * gate_exps,
|
|
struct ggml_tensor * down_exps,
|
|
struct ggml_tensor * exp_probs_b,
|
|
int64_t n_expert,
|
|
int64_t n_expert_used,
|
|
llm_ffn_op_type type_op,
|
|
bool norm_w,
|
|
bool scale_w,
|
|
float w_scale,
|
|
llama_expert_gating_func_type gating_op,
|
|
const llm_build_cb & cb,
|
|
int il) {
|
|
int64_t n_embd = cur->ne[0];
|
|
int64_t n_tokens = cur->ne[1];
|
|
|
|
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
|
|
cb(logits, "ffn_moe_logits", il);
|
|
|
|
ggml_tensor * probs = nullptr;
|
|
switch (gating_op) {
|
|
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
|
|
{
|
|
probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
|
|
} break;
|
|
case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
|
|
{
|
|
probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
|
|
} break;
|
|
default:
|
|
GGML_ABORT("fatal error");
|
|
}
|
|
cb(probs, "ffn_moe_probs", il);
|
|
|
|
// add expert selection bias - introduced in DeepSeek V3
|
|
// leave probs unbiased as it's later used to get expert weights
|
|
ggml_tensor * selection_probs = probs;
|
|
if (exp_probs_b != nullptr) {
|
|
selection_probs = ggml_add(ctx, probs, exp_probs_b);
|
|
cb(selection_probs, "ffn_moe_probs_biased", il);
|
|
}
|
|
|
|
// select experts
|
|
ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
|
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
|
cb(selected_experts, "ffn_moe_topk", il);
|
|
|
|
ggml_tensor * weights = ggml_get_rows(ctx,
|
|
ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
|
cb(weights, "ffn_moe_weights", il);
|
|
|
|
if (norm_w) {
|
|
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
|
|
|
|
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
|
|
cb(weights_sum, "ffn_moe_weights_sum", il);
|
|
|
|
weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
|
|
cb(weights, "ffn_moe_weights_norm", il);
|
|
|
|
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
|
|
}
|
|
if (scale_w) {
|
|
weights = ggml_scale(ctx, weights, w_scale);
|
|
cb(weights, "ffn_moe_weights_scaled", il);
|
|
}
|
|
|
|
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
|
|
ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
|
cb(up, "ffn_moe_up", il);
|
|
|
|
ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
|
cb(gate, "ffn_moe_gate", il);
|
|
|
|
switch (type_op) {
|
|
case LLM_FFN_SILU:
|
|
{
|
|
gate = ggml_silu(ctx, gate);
|
|
cb(gate, "ffn_moe_silu", il);
|
|
} break;
|
|
case LLM_FFN_GELU:
|
|
{
|
|
gate = ggml_gelu(ctx, gate);
|
|
cb(gate, "ffn_moe_gelu", il);
|
|
} break;
|
|
default:
|
|
GGML_ABORT("fatal error");
|
|
}
|
|
|
|
ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
|
|
cb(par, "ffn_moe_gate_par", il);
|
|
|
|
ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
|
cb(experts, "ffn_moe_down", il);
|
|
|
|
experts = ggml_mul(ctx, experts, weights);
|
|
|
|
// aggregate experts
|
|
ggml_tensor * moe_out = nullptr;
|
|
for (int i = 0; i < n_expert_used; ++i) {
|
|
ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
|
|
experts->nb[2], i*experts->nb[1]);
|
|
|
|
if (i == 0) {
|
|
moe_out = cur_expert;
|
|
} else {
|
|
moe_out = ggml_add(ctx, moe_out, cur_expert);
|
|
}
|
|
}
|
|
|
|
if (n_expert_used == 1) {
|
|
// avoid returning a non-contiguous tensor
|
|
moe_out = ggml_cont(ctx, moe_out);
|
|
}
|
|
|
|
return moe_out;
|
|
}
|
|
|
|
static struct ggml_tensor * llm_build_kqv(
|
|
struct ggml_context * ctx,
|
|
struct llama_context & lctx,
|
|
const llama_kv_cache & kv,
|
|
struct ggml_cgraph * graph,
|
|
struct ggml_tensor * wo,
|
|
struct ggml_tensor * wo_b,
|
|
struct ggml_tensor * q_cur,
|
|
struct ggml_tensor * kq_mask,
|
|
int32_t n_tokens,
|
|
int32_t n_kv,
|
|
float kq_scale,
|
|
const llm_build_cb & cb,
|
|
int il) {
|
|
const llama_model & model = lctx.model;
|
|
const llama_hparams & hparams = lctx.model.hparams;
|
|
const llama_cparams & cparams = lctx.cparams;
|
|
|
|
const int64_t n_ctx = cparams.n_ctx;
|
|
const int64_t n_head = hparams.n_head(il);
|
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
|
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
const int64_t n_embd_head_v = hparams.n_embd_head_v;
|
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
|
|
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
|
|
cb(q, "q", il);
|
|
|
|
struct ggml_tensor * k =
|
|
ggml_view_3d(ctx, kv.k_l[il],
|
|
n_embd_head_k, n_kv, n_head_kv,
|
|
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
|
|
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
|
|
0);
|
|
cb(k, "k", il);
|
|
|
|
struct ggml_tensor * cur;
|
|
|
|
if (cparams.flash_attn) {
|
|
GGML_UNUSED(model);
|
|
GGML_UNUSED(n_ctx);
|
|
|
|
// split cached v into n_head heads (not transposed)
|
|
struct ggml_tensor * v =
|
|
ggml_view_3d(ctx, kv.v_l[il],
|
|
n_embd_head_v, n_kv, n_head_kv,
|
|
ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
|
|
ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
|
|
0);
|
|
cb(v, "v", il);
|
|
|
|
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
|
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
|
|
|
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
|
|
|
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
|
|
} else {
|
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
|
cb(kq, "kq", il);
|
|
|
|
// note: this op tends to require high floating point range
|
|
// while for some models F16 is enough, for others it is not, so we default to F32 here
|
|
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
|
|
|
if (model.arch == LLM_ARCH_GROK) {
|
|
// need to do the following:
|
|
// multiply by the attn_output_multiplier of 0.08838834764831845
|
|
// and then:
|
|
// kq = 30 * tanh(kq / 30)
|
|
// before the softmax below
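// the two ops below fuse these steps: tanh(kq * 0.08838834764831845f/30.0f) scaled by 30
// equals 30 * tanh((0.08838834764831845f * kq) / 30)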
|
|
|
|
kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
|
|
kq = ggml_scale(ctx, kq, 30);
|
|
}
|
|
|
|
if (hparams.attn_soft_cap) {
|
|
kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
|
|
kq = ggml_tanh(ctx, kq);
|
|
kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
|
|
}
|
|
|
|
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
|
cb(kq, "kq_soft_max_ext", il);
|
|
|
|
GGML_ASSERT(kv.size == n_ctx);
|
|
|
|
// split cached v into n_head heads
|
|
struct ggml_tensor * v =
|
|
ggml_view_3d(ctx, kv.v_l[il],
|
|
n_kv, n_embd_head_v, n_head_kv,
|
|
ggml_element_size(kv.v_l[il])*n_ctx,
|
|
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
|
|
0);
|
|
cb(v, "v", il);
|
|
|
|
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
|
|
cb(kqv, "kqv", il);
|
|
|
|
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
|
cb(kqv_merged, "kqv_merged", il);
|
|
|
|
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
|
|
cb(cur, "kqv_merged_cont", il);
|
|
}
|
|
|
|
ggml_build_forward_expand(graph, cur);
|
|
|
|
if (wo) {
|
|
cur = llm_build_lora_mm(lctx, ctx, wo, cur);
|
|
}
|
|
|
|
if (wo_b) {
|
|
cb(cur, "kqv_wo", il);
|
|
}
|
|
|
|
if (wo_b) {
|
|
cur = ggml_add(ctx, cur, wo_b);
|
|
}
|
|
|
|
return cur;
|
|
}
|
|
|
|
static struct ggml_tensor * llm_build_kv(
|
|
struct ggml_context * ctx,
|
|
struct llama_context & lctx,
|
|
const llama_kv_cache & kv,
|
|
struct ggml_cgraph * graph,
|
|
struct ggml_tensor * wo,
|
|
struct ggml_tensor * wo_b,
|
|
struct ggml_tensor * k_cur,
|
|
struct ggml_tensor * v_cur,
|
|
struct ggml_tensor * q_cur,
|
|
struct ggml_tensor * kq_mask,
|
|
int32_t n_tokens,
|
|
int32_t kv_head,
|
|
int32_t n_kv,
|
|
float kq_scale,
|
|
const llm_build_cb & cb,
|
|
int il) {
|
|
const llama_hparams & hparams = lctx.model.hparams;
|
|
const llama_cparams & cparams = lctx.cparams;
|
|
|
|
// these nodes are added to the graph together so that they are not reordered
|
|
// by doing so, the number of splits in the graph is reduced
|
|
ggml_build_forward_expand(graph, q_cur);
|
|
ggml_build_forward_expand(graph, k_cur);
|
|
ggml_build_forward_expand(graph, v_cur);
|
|
|
|
llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
|
|
|
|
struct ggml_tensor * cur;
|
|
|
|
cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
|
|
cb(cur, "kqv_out", il);
|
|
|
|
return cur;
|
|
}
|
|
|
|
static struct ggml_tensor * llm_build_copy_mask_state(
|
|
struct ggml_context * ctx,
|
|
struct ggml_cgraph * graph,
|
|
struct ggml_tensor * s,
|
|
struct ggml_tensor * state_copy,
|
|
struct ggml_tensor * state_mask,
|
|
int32_t n_state,
|
|
int32_t kv_size,
|
|
int32_t kv_head,
|
|
int32_t n_kv,
|
|
int32_t n_seqs) {
|
|
struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size);
|
|
|
|
// copy states
|
|
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
|
// this shrinks the tensor's ne[1] to n_kv
|
|
states = ggml_get_rows(ctx, states, state_copy);
|
|
|
|
// clear states of sequences which are starting at the beginning of this batch
|
|
// FIXME: zero-out NANs?
|
|
states = ggml_mul(ctx, states, state_mask);
|
|
|
|
// copy states which won't be changed further (between n_seqs and n_kv)
|
|
ggml_build_forward_expand(graph,
|
|
ggml_cpy(ctx,
|
|
ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
|
|
ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
|
|
|
|
// the part of the states that will be used and modified
|
|
return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
|
|
}
|
|
|
|
// TODO: split
|
|
static struct ggml_tensor * llm_build_mamba(
|
|
struct ggml_context * ctx,
|
|
struct llama_context & lctx,
|
|
const llama_ubatch & ubatch,
|
|
struct ggml_cgraph * graph,
|
|
struct ggml_tensor * cur,
|
|
struct ggml_tensor * state_copy,
|
|
struct ggml_tensor * state_mask,
|
|
int32_t kv_head,
|
|
int32_t n_kv,
|
|
const llm_build_cb & cb,
|
|
int il) {
|
|
const llama_model & model = lctx.model;
|
|
const llama_hparams & hparams = model.hparams;
|
|
const llama_kv_cache & kv = lctx.kv_self;
|
|
const int64_t d_conv = hparams.ssm_d_conv;
|
|
const int64_t d_inner = hparams.ssm_d_inner;
|
|
const int64_t d_state = hparams.ssm_d_state;
|
|
const int64_t dt_rank = hparams.ssm_dt_rank;
|
|
const int64_t n_seqs = ubatch.n_seqs;
|
|
// Some Mamba arch variants (e.g. FalconMamba) apply RMS norm on the B, C & Dt layers
|
|
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
|
|
// Use the same RMS norm as the final layer norm
|
|
const float norm_rms_eps = hparams.f_norm_rms_eps;
|
|
|
|
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
|
|
GGML_ASSERT(n_seqs != 0);
|
|
GGML_ASSERT(ubatch.equal_seqs);
|
|
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
|
|
|
struct ggml_tensor * conv_states_all = kv.k_l[il];
|
|
struct ggml_tensor * ssm_states_all = kv.v_l[il];
|
|
|
|
// (ab)using the KV cache to store the states
|
|
struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
|
|
graph, conv_states_all, state_copy, state_mask,
|
|
hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
|
|
conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
|
|
struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
|
|
graph, ssm_states_all, state_copy, state_mask,
|
|
hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
|
|
ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs);
|
|
|
|
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
|
cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
|
|
|
|
// {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
|
|
struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
|
|
// split the above in two
|
|
// => {d_inner, n_seq_tokens, n_seqs}
|
|
struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
|
|
struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz));
|
|
|
|
// conv
|
|
{
|
|
// => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
|
|
struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0);
|
|
|
|
// copy last (d_conv - 1) columns back into the state cache
|
|
struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
|
|
|
|
ggml_build_forward_expand(graph,
|
|
ggml_cpy(ctx, last_conv,
|
|
ggml_view_1d(ctx, conv_states_all,
|
|
(d_conv - 1)*(d_inner)*(n_seqs),
|
|
kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
|
|
|
|
// 1D convolution
|
|
// The equivalent is to make a self-overlapping view of conv_x
|
|
// over d_conv columns at each stride in the 3rd dimension,
|
|
// then element-wise multiply that with the conv1d weight,
|
|
// then sum the elements of each row,
|
|
// (the last two steps are a dot product over rows (also doable with mul_mat))
|
|
// then permute away the ne[0] dimension,
|
|
// and then you're left with the resulting x tensor.
|
|
// For simultaneous sequences, all sequences need to have the same length.
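// resulting shape: {d_inner, n_seq_tokens, n_seqs}, i.e. one convolved value per channel
// and token, matching the layout expected by the ssm_x projection below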
|
|
x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
|
|
|
|
// bias
|
|
x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
|
|
|
|
x = ggml_silu(ctx, x);
|
|
}
|
|
|
|
// ssm
|
|
{
|
|
// {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
|
|
struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x);
|
|
// split
|
|
struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
|
|
struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
|
|
struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
|
|
|
|
// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
|
|
if (ssm_dt_b_c_rms) {
|
|
dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
|
|
B = ggml_rms_norm(ctx, B, norm_rms_eps);
|
|
C = ggml_rms_norm(ctx, C, norm_rms_eps);
|
|
}
|
|
|
|
// {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
|
|
dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
|
|
dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
|
|
|
|
// Custom operator to optimize the parallel associative scan
|
|
// as described in the Annex D of the Mamba paper.
|
|
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
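// conceptually, per token t and channel: h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t
// and y_t = C_t . h_t (the discretized selective SSM recurrence); ggml_ssm_scan evaluates
// this scan and also produces the final state h for each sequence, copied out below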
|
|
struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
|
|
|
|
// store last states
|
|
ggml_build_forward_expand(graph,
|
|
ggml_cpy(ctx,
|
|
ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
|
|
ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
|
|
|
|
struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
|
|
|
|
// TODO: skip computing output earlier for unused tokens
|
|
|
|
// {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
|
|
y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d));
|
|
y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
|
|
|
|
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
|
|
cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y);
|
|
}
|
|
|
|
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
|
|
cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs);
|
|
cb(cur, "mamba_out", il);
|
|
|
|
return cur;
|
|
}
|
|
|
|
static struct ggml_tensor * llm_build_rwkv6_time_mix(
|
|
struct llama_context & lctx,
|
|
struct ggml_context * ctx,
|
|
const struct llama_layer * layer,
|
|
struct ggml_tensor * cur,
|
|
struct ggml_tensor * x_prev,
|
|
struct ggml_tensor ** wkv_state) {
|
|
size_t n_embd = cur->ne[0];
|
|
size_t n_seq_tokens = cur->ne[1];
|
|
size_t n_seqs = cur->ne[2];
|
|
|
|
size_t head_size = layer->time_mix_first->ne[0];
|
|
size_t head_count = layer->time_mix_first->ne[1];
|
|
|
|
size_t n_tokens = n_seqs * n_seq_tokens;
|
|
|
|
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
|
|
|
|
sx = ggml_reshape_2d(ctx, sx, n_embd, n_tokens);
|
|
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
|
|
|
|
struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
|
|
|
|
xxx = ggml_reshape_4d(
|
|
ctx,
|
|
ggml_tanh(
|
|
ctx,
|
|
ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
|
|
),
|
|
layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
|
|
);
|
|
|
|
xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2));
|
|
|
|
xxx = ggml_mul_mat(
|
|
ctx,
|
|
ggml_reshape_4d(
|
|
ctx,
|
|
layer->time_mix_w2,
|
|
layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
|
|
),
|
|
xxx
|
|
);
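// xxx now packs the five data-dependent token-shift mixes (for w, k, v, r and g),
// each an {n_embd, n_tokens} block; the views below slice them back out in that order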
|
|
|
|
struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
|
|
struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
|
|
struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
|
|
struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
|
|
struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
|
|
|
|
struct ggml_tensor * xw = ggml_add(
|
|
ctx,
|
|
ggml_mul(
|
|
ctx,
|
|
ggml_add(ctx, mw, layer->time_mix_lerp_w),
|
|
sx
|
|
),
|
|
cur
|
|
);
|
|
|
|
struct ggml_tensor * xk = ggml_add(
|
|
ctx,
|
|
ggml_mul(
|
|
ctx,
|
|
ggml_add(ctx, mk, layer->time_mix_lerp_k),
|
|
sx
|
|
),
|
|
cur
|
|
);
|
|
|
|
struct ggml_tensor * xv = ggml_add(
|
|
ctx,
|
|
ggml_mul(
|
|
ctx,
|
|
ggml_add(ctx, mv, layer->time_mix_lerp_v),
|
|
sx
|
|
),
|
|
cur
|
|
);
|
|
|
|
struct ggml_tensor * xr = ggml_add(
|
|
ctx,
|
|
ggml_mul(
|
|
ctx,
|
|
ggml_add(ctx, mr, layer->time_mix_lerp_r),
|
|
sx
|
|
),
|
|
cur
|
|
);
|
|
|
|
struct ggml_tensor * xg = ggml_add(
|
|
ctx,
|
|
ggml_mul(
|
|
ctx,
|
|
ggml_add(ctx, mg, layer->time_mix_lerp_g),
|
|
sx
|
|
),
|
|
cur
|
|
);
|
|
|
|
struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
|
|
struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
|
|
struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
|
|
struct ggml_tensor * g = ggml_silu(
|
|
ctx,
|
|
llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
|
|
);
|
|
|
|
struct ggml_tensor * w = ggml_mul_mat(
|
|
ctx,
|
|
layer->time_mix_decay_w2,
|
|
ggml_tanh(
|
|
ctx,
|
|
ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
|
|
)
|
|
);
|
|
|
|
w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
|
|
w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
|
|
w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
|
|
|
|
k = ggml_transpose(ctx, k);
|
|
v = ggml_transpose(ctx, v);
|
|
r = ggml_transpose(ctx, r);
|
|
|
|
struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
|
|
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
|
|
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
|
|
|
|
// group norm with head_count groups
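// reshaping to {n_embd/head_count, head_count, n_tokens} and applying ggml_norm over the
// first dimension normalizes each head independently, which is equivalent to a group norm
// with head_count groups (eps = 64e-5)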
|
|
cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
|
|
cur = ggml_norm(ctx, cur, 64e-5f);
|
|
|
|
// Convert back to regular vectors.
|
|
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
|
|
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
|
|
|
|
cur = ggml_mul(ctx, cur, g);
|
|
cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
|
|
|
|
return ggml_reshape_3d(ctx, cur, n_embd, n_seq_tokens, n_seqs);
|
|
}
|
|
|
|
static struct ggml_tensor * llm_build_rwkv6_channel_mix(
|
|
struct llama_context & lctx,
|
|
struct ggml_context * ctx,
|
|
const struct llama_layer * layer,
|
|
struct ggml_tensor * cur,
|
|
struct ggml_tensor * x_prev) {
|
|
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
|
|
struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
|
|
struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
|
|
|
|
struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
|
|
struct ggml_tensor * k = ggml_sqr(
|
|
ctx,
|
|
ggml_relu(
|
|
ctx,
|
|
llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
|
|
)
|
|
);
|
|
|
|
return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
|
|
}
|
|
|
|
struct llm_build_context {
|
|
const llama_model & model;
|
|
llama_context & lctx;
|
|
const llama_hparams & hparams;
|
|
const llama_cparams & cparams;
|
|
const llama_ubatch & ubatch;
|
|
const llama_kv_cache & kv_self;
|
|
|
|
const int64_t n_embd;
|
|
const int64_t n_layer;
|
|
const int64_t n_rot;
|
|
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
|
|
const int64_t n_head;
|
|
const int64_t n_head_kv;
|
|
const int64_t n_embd_head_k;
|
|
const int64_t n_embd_k_gqa;
|
|
const int64_t n_embd_head_v;
|
|
const int64_t n_embd_v_gqa;
|
|
const int64_t n_expert;
|
|
const int64_t n_expert_used;
|
|
|
|
const float freq_base;
|
|
const float freq_scale;
|
|
const float ext_factor;
|
|
const float attn_factor;
|
|
const float beta_fast;
|
|
const float beta_slow;
|
|
const float norm_eps;
|
|
const float norm_rms_eps;
|
|
|
|
const int32_t n_tokens;
|
|
const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
|
|
const int32_t n_outputs;
|
|
const int32_t n_outputs_enc;
|
|
const int32_t kv_head; // index of where we store new KV data in the cache
|
|
const int32_t n_ctx_orig;
|
|
|
|
const bool flash_attn;
|
|
|
|
const enum llama_pooling_type pooling_type;
|
|
const enum llama_rope_type rope_type;
|
|
|
|
const llm_build_cb & cb;
|
|
|
|
std::vector<uint8_t> & buf_compute_meta;
|
|
|
|
struct ggml_context * ctx0 = nullptr;
|
|
|
|
// TODO: consider making the entire interface noexcept
|
|
llm_build_context(
|
|
llama_context & lctx,
|
|
const llama_ubatch & ubatch,
|
|
const llm_build_cb & cb,
|
|
bool worst_case) :
|
|
model (lctx.model),
|
|
lctx (lctx),
|
|
hparams (model.hparams),
|
|
cparams (lctx.cparams),
|
|
ubatch (ubatch),
|
|
kv_self (lctx.kv_self),
|
|
n_embd (hparams.n_embd),
|
|
n_layer (hparams.n_layer),
|
|
n_rot (hparams.n_rot),
|
|
n_ctx (cparams.n_ctx),
|
|
n_head (hparams.n_head()),
|
|
n_head_kv (hparams.n_head_kv()),
|
|
n_embd_head_k (hparams.n_embd_head_k),
|
|
n_embd_k_gqa (hparams.n_embd_k_gqa()),
|
|
n_embd_head_v (hparams.n_embd_head_v),
|
|
n_embd_v_gqa (hparams.n_embd_v_gqa()),
|
|
n_expert (hparams.n_expert),
|
|
n_expert_used (hparams.n_expert_used),
|
|
freq_base (cparams.rope_freq_base),
|
|
freq_scale (cparams.rope_freq_scale),
|
|
ext_factor (cparams.yarn_ext_factor),
|
|
attn_factor (cparams.yarn_attn_factor),
|
|
beta_fast (cparams.yarn_beta_fast),
|
|
beta_slow (cparams.yarn_beta_slow),
|
|
norm_eps (hparams.f_norm_eps),
|
|
norm_rms_eps (hparams.f_norm_rms_eps),
|
|
n_tokens (ubatch.n_tokens),
|
|
n_kv (worst_case ? kv_self.size : kv_self.n),
|
|
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
|
|
n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
|
|
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
|
|
n_ctx_orig (cparams.n_ctx_orig_yarn),
|
|
flash_attn (cparams.flash_attn),
|
|
pooling_type (cparams.pooling_type),
|
|
rope_type (hparams.rope_type),
|
|
cb (cb),
|
|
buf_compute_meta (lctx.buf_compute_meta) {
|
|
// all initializations should be done in init()
|
|
}
|
|
|
|
void init() {
|
|
struct ggml_init_params params = {
|
|
/*.mem_size =*/ buf_compute_meta.size(),
|
|
/*.mem_buffer =*/ buf_compute_meta.data(),
|
|
/*.no_alloc =*/ true,
|
|
};
|
|
|
|
ctx0 = ggml_init(params);
|
|
|
|
lctx.inp_tokens = nullptr;
|
|
lctx.inp_embd = nullptr;
|
|
lctx.inp_pos = nullptr;
|
|
lctx.inp_out_ids = nullptr;
|
|
lctx.inp_KQ_mask = nullptr;
|
|
lctx.inp_KQ_mask_swa = nullptr;
|
|
lctx.inp_K_shift = nullptr;
|
|
lctx.inp_mean = nullptr;
|
|
lctx.inp_cls = nullptr;
|
|
lctx.inp_s_copy = nullptr;
|
|
lctx.inp_s_mask = nullptr;
|
|
lctx.inp_s_seq = nullptr;
|
|
lctx.inp_pos_bucket = nullptr;
|
|
lctx.inp_embd_enc = nullptr;
|
|
lctx.inp_KQ_mask_cross = nullptr;
|
|
}
|
|
|
|
void free() {
|
|
ggml_free(ctx0);
|
|
ctx0 = nullptr;
|
|
}
|
|
|
|
struct ggml_cgraph * build_k_shift() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
GGML_ASSERT(kv_self.size == n_ctx);
|
|
|
|
lctx.inp_K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
|
cb(lctx.inp_K_shift, "K_shift", -1);
|
|
ggml_set_input(lctx.inp_K_shift);
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
|
struct ggml_tensor * k =
|
|
ggml_view_3d(ctx0, kv_self.k_l[il],
|
|
n_embd_head_k, n_head_kv, n_ctx,
|
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
|
0);
|
|
|
|
struct ggml_tensor * tmp;
|
|
if (ggml_is_quantized(k->type)) {
|
|
// dequantize to f32 -> RoPE -> quantize back
|
|
tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
|
|
cb(tmp, "K_f32", il);
|
|
for (auto & backend : lctx.backends) {
|
|
// Figure out which backend the KV cache belongs to
|
|
if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {
|
|
ggml_backend_sched_set_tensor_backend(lctx.sched.get(), tmp, backend.get());
|
|
break;
|
|
}
|
|
}
|
|
tmp = ggml_rope_ext_inplace(ctx0, tmp,
|
|
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
|
cb(tmp, "K_shifted_f32", il);
|
|
tmp = ggml_cpy(ctx0, tmp, k);
|
|
} else {
|
|
// we rotate only the first n_rot dimensions
|
|
tmp = ggml_rope_ext_inplace(ctx0, k,
|
|
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
|
}
|
|
cb(tmp, "K_shifted", il);
|
|
ggml_build_forward_expand(gf, tmp);
|
|
}
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
for (uint32_t i = 0; i < ids.size(); ++i) {
|
|
const uint32_t id = ids[i];
|
|
|
|
if (i == id || id == ids.size()) {
|
|
continue;
|
|
}
|
|
|
|
uint32_t nm = 1;
|
|
|
|
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
|
nm++;
|
|
}
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
|
|
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
|
|
n_embd_k_gqa, nm,
|
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
|
|
|
|
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
|
|
n_embd_k_gqa, nm,
|
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
|
|
|
|
ggml_tensor * view_v_src;
|
|
ggml_tensor * view_v_dst;
|
|
|
|
if (flash_attn) {
|
|
// NOTE: the V cache is not transposed when using flash attention
|
|
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
|
n_embd_v_gqa, nm,
|
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
|
|
|
|
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
|
n_embd_v_gqa, nm,
|
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
|
|
} else {
|
|
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
|
nm, n_embd_v_gqa,
|
|
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
|
ggml_row_size(kv_self.v_l[il]->type, i));
|
|
|
|
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
|
nm, n_embd_v_gqa,
|
|
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
|
ggml_row_size(kv_self.v_l[il]->type, id));
|
|
}
|
|
|
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
|
|
}
|
|
|
|
i += nm - 1;
|
|
}
|
|
|
|
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_tensor * build_inp_pos() {
|
|
lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
|
cb(lctx.inp_pos, "inp_pos", -1);
|
|
ggml_set_input(lctx.inp_pos);
|
|
return lctx.inp_pos;
|
|
}
|
|
|
|
struct ggml_tensor * build_rope_factors(int il) {
|
|
// choose long/short freq factors based on the context size
|
|
const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
|
|
|
|
if (model.layers[il].rope_freqs != nullptr) {
|
|
return model.layers[il].rope_freqs;
|
|
}
|
|
|
|
if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
|
|
return model.layers[il].rope_long;
|
|
}
|
|
|
|
return model.layers[il].rope_short;
|
|
}
|
|
|
|
struct ggml_tensor * build_inp_out_ids() {
|
|
lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
|
|
cb(lctx.inp_out_ids, "inp_out_ids", -1);
|
|
ggml_set_input(lctx.inp_out_ids);
|
|
return lctx.inp_out_ids;
|
|
}
|
|
|
|
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
|
|
lctx.inp_KQ_mask = causal
|
|
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
|
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
|
|
ggml_set_input(lctx.inp_KQ_mask);
|
|
|
|
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
|
|
}
|
|
|
|
struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
|
|
GGML_ASSERT(hparams.n_swa > 0);
|
|
|
|
lctx.inp_KQ_mask_swa = causal
|
|
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
|
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
|
|
ggml_set_input(lctx.inp_KQ_mask_swa);
|
|
|
|
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
|
|
}
|
|
|
|
struct ggml_tensor * build_inp_mean() {
|
|
lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
|
|
cb(lctx.inp_mean, "inp_mean", -1);
|
|
ggml_set_input(lctx.inp_mean);
|
|
return lctx.inp_mean;
|
|
}
|
|
|
|
struct ggml_tensor * build_inp_cls() {
|
|
lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
|
cb(lctx.inp_cls, "inp_cls", -1);
|
|
ggml_set_input(lctx.inp_cls);
|
|
return lctx.inp_cls;
|
|
}
|
|
|
|
struct ggml_tensor * build_inp_s_copy() {
|
|
lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
|
cb(lctx.inp_s_copy, "inp_s_copy", -1);
|
|
ggml_set_input(lctx.inp_s_copy);
|
|
return lctx.inp_s_copy;
|
|
}
|
|
|
|
struct ggml_tensor * build_inp_s_mask() {
|
|
lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
|
|
cb(lctx.inp_s_mask, "inp_s_mask", -1);
|
|
ggml_set_input(lctx.inp_s_mask);
|
|
return lctx.inp_s_mask;
|
|
}
|
|
|
|
struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
|
|
// find result_norm tensor for input
|
|
struct ggml_tensor * inp = nullptr;
|
|
for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
|
|
inp = ggml_graph_node(gf, i);
|
|
if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
|
|
break;
|
|
} else {
|
|
inp = nullptr;
|
|
}
|
|
}
|
|
GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
|
|
|
|
struct ggml_tensor * cur;
|
|
|
|
switch (pooling_type) {
|
|
case LLAMA_POOLING_TYPE_NONE:
|
|
{
|
|
cur = inp;
|
|
} break;
|
|
case LLAMA_POOLING_TYPE_MEAN:
|
|
{
|
|
struct ggml_tensor * inp_mean = build_inp_mean();
|
|
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
|
|
} break;
|
|
case LLAMA_POOLING_TYPE_CLS:
|
|
case LLAMA_POOLING_TYPE_LAST:
|
|
{
|
|
struct ggml_tensor * inp_cls = build_inp_cls();
|
|
cur = ggml_get_rows(ctx0, inp, inp_cls);
|
|
} break;
|
|
case LLAMA_POOLING_TYPE_RANK:
|
|
{
|
|
struct ggml_tensor * inp_cls = build_inp_cls();
|
|
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
|
|
|
// classification head
|
|
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
|
GGML_ASSERT(model.cls != nullptr);
|
|
GGML_ASSERT(model.cls_b != nullptr);
|
|
|
|
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
|
|
cur = ggml_tanh(ctx0, cur);
|
|
|
|
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
|
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
|
if (model.cls_out) {
|
|
GGML_ASSERT(model.cls_out_b != nullptr);
|
|
|
|
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
|
|
}
|
|
} break;
|
|
default:
|
|
{
|
|
GGML_ABORT("unknown pooling type");
|
|
}
|
|
}
|
|
|
|
cb(cur, "result_embd_pooled", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_tensor * llm_build_pos_bucket(bool causal) {
|
|
if (causal) {
|
|
lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
|
|
} else {
|
|
lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
|
|
}
|
|
|
|
ggml_set_input(lctx.inp_pos_bucket);
|
|
cb(lctx.inp_pos_bucket, "pos_bucket", -1);
|
|
|
|
return lctx.inp_pos_bucket;
|
|
}
|
|
|
|
struct ggml_tensor * llm_build_pos_bias(struct ggml_tensor * pos_bucket, struct ggml_tensor * attn_rel_b) {
struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0);
cb(pos_bucket_1d, "pos_bucket_1d", -1);

struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
cb(pos_bias, "pos_bias", -1);

pos_bias = ggml_view_3d(ctx0, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0], 0);
cb(pos_bias, "pos_bias", -1);

pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
cb(pos_bias, "pos_bias", -1);

pos_bias = ggml_cont(ctx0, pos_bias);
cb(pos_bias, "pos_bias", -1);

return pos_bias;
}
|
|
|
|
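// Inputs consumed by the decoder side of encoder-decoder models: inp_embd_enc carries the
// encoder's output embeddings (n_embd x n_outputs_enc) and inp_KQ_mask_cross is the matching
// cross-attention mask, padded to GGML_KQ_MASK_PAD like the regular KQ masks.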
struct ggml_tensor * llm_build_inp_embd_enc() {
const int64_t n_embd = hparams.n_embd;
lctx.inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
ggml_set_input(lctx.inp_embd_enc);
cb(lctx.inp_embd_enc, "embd_enc", -1);
return lctx.inp_embd_enc;
}
|
|
|
|
struct ggml_tensor * llm_build_inp_KQ_mask_cross() {
lctx.inp_KQ_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
ggml_set_input(lctx.inp_KQ_mask_cross);
cb(lctx.inp_KQ_mask_cross, "KQ_mask_cross", -1);
return lctx.inp_KQ_mask_cross;
}
|
|
|
|
struct ggml_cgraph * build_llama() {
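// Graph for the LLaMA family (also reused, with extra scaling hooks, by Granite-style models).
// Per layer, roughly: RMSNorm -> Q/K/V projections (optional biases) -> RoPE on Q and K ->
// KV-cache attention -> residual add -> RMSNorm -> SwiGLU FFN, or MoE when ffn_gate_inp is set
// -> residual add. A final RMSNorm and the lm_head projection produce "result_output".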
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable copy of n_tokens; reduced to n_outputs in the last layer so that unused tokens are skipped
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
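// kq_scale: 1/sqrt(n_embd_head) by default; a non-zero f_attention_scale (models that ship
// their own attention multiplier) takes precedence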
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
|
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
// For Granite architecture
|
|
if (hparams.f_residual_scale) {
|
|
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
if (model.layers[il].ffn_gate_inp == nullptr) {
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
} else {
|
|
// MoE branch
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_moe_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_gate_inp,
|
|
model.layers[il].ffn_up_exps,
|
|
model.layers[il].ffn_gate_exps,
|
|
model.layers[il].ffn_down_exps,
|
|
nullptr,
|
|
n_expert, n_expert_used,
|
|
LLM_FFN_SILU, true,
|
|
false, 0.0,
|
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
|
cb, il);
|
|
cb(cur, "ffn_moe_out", il);
|
|
}
|
|
|
|
// For Granite architecture
|
|
if (hparams.f_residual_scale) {
|
|
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
// For Granite architecture
|
|
if (hparams.f_logit_scale) {
|
|
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
|
|
}
|
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_deci() {
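// Deci/Nemotron-style variant of the LLaMA graph where head counts can differ per layer:
// n_head == 0 marks an attention-free layer (the block reduces to its FFN), n_head > 0 with
// n_head_kv == 0 is the "linear attention" case that only applies the wo projection, and
// otherwise the layer runs the usual RoPE + KV-cache attention path.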
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable copy of n_tokens; reduced to n_outputs in the last layer so that unused tokens are skipped
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
|
const int64_t n_head = hparams.n_head(il);
|
|
|
|
if (n_head == 0) {
|
|
// attention-free layer of Llama-3_1-Nemotron-51B
|
|
cur = inpL;
|
|
} else {
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
}
|
|
|
|
if (n_head > 0 && n_head_kv == 0) {
|
|
// "linear attention" of Llama-3_1-Nemotron-51B
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
|
cb(cur, "wo", il);
|
|
} else if (n_head > 0) {
|
|
// self-attention
|
|
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
|
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
// For Granite architecture
|
|
if (hparams.f_residual_scale) {
|
|
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
|
}
|
|
|
|
// modified to support attention-free layer of Llama-3_1-Nemotron-51B
|
|
struct ggml_tensor * ffn_inp = cur;
|
|
if (n_head > 0) {
|
|
ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
}
|
|
|
|
// feed-forward network
|
|
if (model.layers[il].ffn_gate_inp == nullptr) {
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
// For Granite architecture
|
|
if (hparams.f_residual_scale) {
|
|
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
// For Granite architecture
|
|
if (hparams.f_logit_scale) {
|
|
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
|
|
}
|
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_baichuan() {
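// Baichuan graph: the 7B variant applies RoPE to Q/K using inp_pos, while the 13B variant skips
// RoPE entirely (it presumably relies on the ALiBi-style bias applied through the KQ mask inside
// the attention helper); any other model type aborts.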
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr;
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
switch (model.type) {
|
|
case MODEL_7B:
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
break;
|
|
case MODEL_13B:
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
|
|
break;
|
|
default:
|
|
GGML_ABORT("fatal error");
|
|
}
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_xverse() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_falcon() {
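// Falcon graph: a fused wqkv projection is split into Q/K/V views, Q/K get NeoX RoPE, and the
// attention and FFN branches run in parallel off the same attn_norm (attn_norm_2 is the extra
// norm used by Falcon-40B); both branch outputs are then summed with the layer input.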
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * attn_norm;
|
|
|
|
attn_norm = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(attn_norm, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
if (model.layers[il].attn_norm_2) {
|
|
// Falcon-40B
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm_2,
|
|
model.layers[il].attn_norm_2_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm_2", il);
|
|
} else {
|
|
cur = attn_norm;
|
|
}
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
// NeoX-style RoPE (mode = 2)
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = cur;
|
|
|
|
// feed forward
|
|
{
|
|
cur = llm_build_ffn(ctx0, lctx, attn_norm, // !! use the attn norm, not the result
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_grok() {
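// Grok-1 graph: token embeddings are scaled by the embedding multiplier, attention uses a raw
// kq_scale of 1.0f, the FFN is a GELU MoE, optional attn_out_norm/layer_out_norm are applied
// before the residual adds, and the final logits are scaled by the output multiplier below.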
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable copy of n_tokens; reduced to n_outputs in the last layer so that unused tokens are skipped
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// multiply by embedding_multiplier_scale of 78.38367176906169
|
|
inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
// Grok
|
|
// if attn_out_norm is present then apply it before adding the input
|
|
if (model.layers[il].attn_out_norm) {
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].attn_out_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_out_norm", il);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
// MoE branch
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_moe_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_gate_inp,
|
|
model.layers[il].ffn_up_exps,
|
|
model.layers[il].ffn_gate_exps,
|
|
model.layers[il].ffn_down_exps,
|
|
nullptr,
|
|
n_expert, n_expert_used,
|
|
LLM_FFN_GELU, true,
|
|
false, 0.0,
|
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
|
cb, il);
|
|
cb(cur, "ffn_moe_out", il);
|
|
|
|
// Grok
|
|
// if layer_out_norm is present then apply it before adding the input
|
|
// Idea: maybe ffn_out_norm is a better name
|
|
if (model.layers[il].layer_out_norm) {
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].layer_out_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "layer_out_norm", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
// Grok
|
|
// multiply logits by output_multiplier_scale of 0.5773502691896257
|
|
|
|
cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
|
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_dbrx() {
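// DBRX graph: LayerNorm (not RMSNorm) throughout, a fused wqkv projection clamped to
// +/- f_clamp_kqv before being split into Q/K/V views, RoPE on Q/K, and a SiLU-gated MoE FFN.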
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable copy of n_tokens; reduced to n_outputs in the last layer so that unused tokens are skipped
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
struct ggml_tensor * Qcur = nullptr;
|
|
struct ggml_tensor * Kcur = nullptr;
|
|
struct ggml_tensor * Vcur = nullptr;
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
|
|
cb(cur, "wqkv_clamped", il);
|
|
|
|
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
// MoE branch
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].attn_out_norm, NULL,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_out_norm", il);
|
|
|
|
cur = llm_build_moe_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_gate_inp,
|
|
model.layers[il].ffn_up_exps,
|
|
model.layers[il].ffn_gate_exps,
|
|
model.layers[il].ffn_down_exps,
|
|
nullptr,
|
|
n_expert, n_expert_used,
|
|
LLM_FFN_SILU, true,
|
|
false, 0.0,
|
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
|
cb, il);
|
|
cb(cur, "ffn_moe_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_starcoder() {
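// StarCoder graph: learned absolute position embeddings (model.pos_embd) are added to the token
// embeddings up front, attention uses a fused wqkv projection with bias and no RoPE, and the FFN
// is a sequential GELU MLP.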
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
struct ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
|
cb(pos, "pos_embd", -1);
|
|
|
|
inpL = ggml_add(ctx0, inpL, pos);
|
|
cb(inpL, "inpL", -1);
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
|
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// add the input
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// FF
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_refact() {
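// Refact graph: structurally a small LLaMA-style block (RMSNorm, SwiGLU FFN), but note that no
// positional input is built here; Q/K are only reshaped, so positional information presumably
// enters through the ALiBi-style bias carried by the KQ mask.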
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_bert() {
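// BERT-style encoder graph: non-causal mask (build_inp_KQ_mask(false)), token-type row 0 added
// to the embeddings, an embedding LayerNorm, and attention computed inline (no KV cache) using
// either separate Q/K/V projections (BERT, JinaBERT) or a fused wqkv + RoPE path; the result is
// exposed as "result_embd" for the pooling head in append_pooling().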
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
struct ggml_tensor * inp_pos = nullptr;
|
|
|
|
if (model.arch != LLM_ARCH_JINA_BERT_V2) {
|
|
inp_pos = build_inp_pos();
|
|
}
|
|
|
|
// construct input embeddings (token, type, position)
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// token types are hardcoded to zero ("Sentence A")
|
|
struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
|
|
inpL = ggml_add(ctx0, inpL, type_row0);
|
|
if (model.arch == LLM_ARCH_BERT) {
|
|
inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
|
|
}
|
|
cb(inpL, "inp_embd", -1);
|
|
|
|
// embed layer norm
|
|
inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
|
|
cb(inpL, "inp_norm", -1);
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false);
|
|
|
|
// iterate layers
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * cur = inpL;
|
|
|
|
struct ggml_tensor * Qcur;
|
|
struct ggml_tensor * Kcur;
|
|
struct ggml_tensor * Vcur;
|
|
|
|
// self-attention
|
|
if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
|
|
Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
if (model.layers[il].attn_q_norm) {
|
|
Qcur = llm_build_norm(ctx0, Qcur, hparams,
|
|
model.layers[il].attn_q_norm,
|
|
model.layers[il].attn_q_norm_b,
|
|
LLM_NORM, cb, il);
|
|
}
|
|
|
|
Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur), model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
if (model.layers[il].attn_k_norm) {
|
|
Kcur = llm_build_norm(ctx0, Kcur, hparams,
|
|
model.layers[il].attn_k_norm,
|
|
model.layers[il].attn_k_norm_b,
|
|
LLM_NORM, cb, il);
|
|
}
|
|
Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur), model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
} else {
|
|
// compute Q and K and RoPE them
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
|
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
|
|
|
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
|
cb(kq, "kq", il);
|
|
|
|
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
|
|
cb(kq, "kq_soft_max_ext", il);
|
|
|
|
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
|
|
cb(v, "v", il);
|
|
|
|
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
|
|
cb(kqv, "kqv", il);
|
|
|
|
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
|
cb(kqv_merged, "kqv_merged", il);
|
|
|
|
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
|
cb(cur, "kqv_merged_cont", il);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
|
if (model.layers[il].bo) {
cb(cur, "kqv_wo", il);
cur = ggml_add(ctx0, cur, model.layers[il].bo);
}
|
|
cb(cur, "kqv_out", il);
|
|
|
|
if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// re-add the layer input
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
|
|
// attention layer norm
|
|
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, cb, il);
|
|
|
|
if (model.layers[il].attn_norm_2 != nullptr) {
|
|
cur = ggml_add(ctx0, cur, inpL); // re-add the layer input
|
|
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, cb, il);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = cur;
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
if (model.arch == LLM_ARCH_BERT) {
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
} else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
|
|
} else {
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
}
|
|
cb(cur, "ffn_out", il);
|
|
|
|
// residual connection: the attention output (ffn_inp) bypasses the intermediate (FFN) layer
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
|
|
// output layer norm
|
|
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].layer_out_norm, model.layers[il].layer_out_norm_b, LLM_NORM, cb, il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cb(cur, "result_embd", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_bloom() {
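// BLOOM graph: the token embeddings are normalized by tok_norm first, attention uses a fused
// wqkv projection with bias and no RoPE (positions presumably come from the ALiBi bias applied
// via the KQ mask), and the FFN is a sequential GELU MLP.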
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
inpL = llm_build_norm(ctx0, inpL, hparams,
|
|
model.tok_norm,
|
|
model.tok_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(inpL, "inp_norm", -1);
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
|
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// Add the input
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// FF
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_mpt() {
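// MPT graph: optional learned position embeddings (model.pos_embd), a fused wqkv projection
// with optional bias, optional clamping to +/- f_clamp_kqv, and optional Q/K LayerNorm before
// the usual KV-cache attention; the FFN is a sequential GELU MLP with an optional ffn_act tensor.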
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * pos;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
if (model.pos_embd) {
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
|
cb(pos, "pos_embd", -1);
|
|
|
|
inpL = ggml_add(ctx0, inpL, pos);
|
|
cb(inpL, "inpL", -1);
|
|
}
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * attn_norm;
|
|
|
|
attn_norm = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(attn_norm, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
cur = attn_norm;
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
if (model.layers[il].bqkv){
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
}
|
|
|
|
if (hparams.f_clamp_kqv > 0.0f) {
|
|
cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
|
|
cb(cur, "wqkv_clamped", il);
|
|
}
|
|
|
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
// Q/K Layernorm
|
|
if (model.layers[il].attn_q_norm) {
|
|
Qcur = llm_build_norm(ctx0, Qcur, hparams,
|
|
model.layers[il].attn_q_norm,
|
|
model.layers[il].attn_q_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = llm_build_norm(ctx0, Kcur, hparams,
|
|
model.layers[il].attn_k_norm,
|
|
model.layers[il].attn_k_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
} else {
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// Add the input
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed forward
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
model.layers[il].ffn_act,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_stablelm() {
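// StableLM graph. Note it allocates a default-size graph (ggml_new_graph) rather than the custom
// size used by the other builders. Per layer: LayerNorm, Q/K/V projections with optional biases,
// optional Q/K norms, RoPE, attention, then either a normed FFN or a parallel-residual path that
// feeds the attention-norm output (inpSA) straight into the FFN when ffn_norm is absent.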
|
|
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
struct ggml_tensor * inpSA = cur;
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
cb(Qcur, "Qcur", il);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
if (model.layers[il].attn_q_norm) {
|
|
Qcur = llm_build_norm(ctx0, Qcur, hparams,
|
|
model.layers[il].attn_q_norm,
|
|
NULL,
|
|
LLM_NORM, cb, il);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
if (model.layers[il].attn_k_norm) {
|
|
Kcur = llm_build_norm(ctx0, Kcur, hparams,
|
|
model.layers[il].attn_k_norm,
|
|
NULL,
|
|
LLM_NORM, cb, il);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, Qcur, inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, Kcur, inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
if (model.layers[il].ffn_norm) {
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
} else {
|
|
// parallel residual
|
|
cur = inpSA;
|
|
}
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_qwen() {
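// Qwen (v1) graph: a fused wqkv projection with bias whose Q, K and V slices are each n_embd
// wide (hence the 2*n_embd offset for the V view below), NeoX RoPE on Q/K, and a SwiGLU FFN.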
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
|
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
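// the fused QKV projection is split into three contiguous views of n_embd columns each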
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
// use NeoX-style RoPE (mode = 2)
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_qwen2() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_qwen2vl() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
|
|
cb(lctx.inp_pos, "inp_pos", -1);
|
|
ggml_set_input(lctx.inp_pos);
|
|
struct ggml_tensor * inp_pos = lctx.inp_pos;
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
int sections[4];
|
|
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
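// note: inp_pos carries 4 position values per token and sections[] the 4 section sizes
// consumed by the multi-section rope (ggml_rope_multi) below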
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_multi(
|
|
ctx0,
|
|
ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_multi(
|
|
ctx0,
|
|
ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_qwen2moe() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// MoE branch
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
ggml_tensor * moe_out =
|
|
llm_build_moe_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_gate_inp,
|
|
model.layers[il].ffn_up_exps,
|
|
model.layers[il].ffn_gate_exps,
|
|
model.layers[il].ffn_down_exps,
|
|
nullptr,
|
|
n_expert, n_expert_used,
|
|
LLM_FFN_SILU, false,
|
|
false, 0.0,
|
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
|
cb, il);
|
|
cb(moe_out, "ffn_moe_out", il);
|
|
|
|
// FFN shared expert
|
|
{
|
|
ggml_tensor * cur_gate_inp = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
|
|
cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
|
|
|
|
// sigmoid gate: silu(x)/x = x*sigmoid(x)/x = sigmoid(x)
|
|
ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
|
|
cb(cur_gate, "ffn_shexp_gate", il);
|
|
|
|
ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up_shexp, NULL, NULL,
|
|
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
|
model.layers[il].ffn_down_shexp, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur_ffn, "ffn_shexp", il);
|
|
|
|
ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate);
|
|
cb(ffn_shexp_out, "ffn_shexp_out", il);
|
|
|
|
moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out);
|
|
cb(moe_out, "ffn_out", il);
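// the sigmoid-gated shared-expert output is added to the routed experts' output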
|
|
|
|
cur = moe_out;
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_phi2() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * attn_norm_output;
|
|
struct ggml_tensor * ffn_output;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(attn_norm_output, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
struct ggml_tensor * Qcur = nullptr;
|
|
struct ggml_tensor * Kcur = nullptr;
|
|
struct ggml_tensor * Vcur = nullptr;
|
|
|
|
if (model.layers[il].wqkv) {
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
|
|
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
} else {
|
|
Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
|
|
Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
|
|
Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
|
|
}
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
// with phi2, we scale the Q to avoid precision issues
|
|
// ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
|
|
Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
|
|
cb(Qcur, "Qcur", il);
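// since Q is pre-scaled here, llm_build_kv below is called with kq_scale = 1.0f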
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
|
|
}
|
|
|
|
// FF
|
|
{
|
|
ffn_output = llm_build_ffn(ctx0, lctx, attn_norm_output,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(ffn_output, "ffn_out", il);
|
|
}
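// parallel residual: the attention output and the FFN output (both computed from the
// same normalized input, attn_norm_output) are added to the layer input in one step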
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_output);
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output_no_bias", -1);
|
|
|
|
cur = ggml_add(ctx0, cur, model.output_b);
|
|
cb(cur, "result_output", -1);
|
|
ggml_build_forward_expand(gf, cur);
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_phi3() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = nullptr;
|
|
if (hparams.n_swa == 0) {
|
|
// Phi-4 doesn't use sliding window attention
|
|
KQ_mask = build_inp_KQ_mask();
|
|
} else {
|
|
KQ_mask = build_inp_KQ_mask_swa();
|
|
}
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
auto residual = inpL;
|
|
|
|
// self-attention
|
|
{
|
|
// rope freq factors for 128k context
|
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
|
|
|
struct ggml_tensor * attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(attn_norm_output, "attn_norm", il);
|
|
|
|
struct ggml_tensor * Qcur = nullptr;
|
|
struct ggml_tensor * Kcur = nullptr;
|
|
struct ggml_tensor * Vcur = nullptr;
|
|
|
|
if (model.layers[il].wqkv) {
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
|
|
cb(cur, "wqkv", il);
|
|
|
|
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
|
|
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
|
|
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
|
|
}
|
|
else {
|
|
Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
|
|
Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
|
|
Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
|
|
}
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
residual = ggml_get_rows(ctx0, residual, inp_out_ids);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, residual);
|
|
residual = cur;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
// FF
|
|
// special-case: the up and gate tensors are merged into a single tensor
|
|
// TODO: support this inside llm_build_ffn
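// with LLM_FFN_SWIGLU the merged projection is expected to be split in half downstream
// (SiLU applied to the gate half, then multiplied by the up half)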
|
|
{
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, residual, cur);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm,
|
|
NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
|
|
struct ggml_cgraph * build_plamo() {
|
|
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
struct ggml_tensor * attention_norm = cur;
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
|
|
n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
struct ggml_tensor * sa_out = cur;
|
|
|
|
cur = attention_norm;
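// parallel residual: the FFN below consumes the same normalized input as the attention,
// and both the attention output (sa_out) and the FFN output are added back to inpL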
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, sa_out);
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_gpt2() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * pos;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
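// GPT-2 uses learned absolute position embeddings: rows of pos_embd are gathered by
// token position and added to the token embeddings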
|
|
|
|
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
|
cb(pos, "pos_embd", -1);
|
|
|
|
inpL = ggml_add(ctx0, inpL, pos);
|
|
cb(inpL, "inpL", -1);
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
|
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
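// row layout of the fused QKV output: [ Q (n_embd) | K (n_embd_gqa) | V (n_embd_gqa) ], hence the byte offsets above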
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// add the input
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// FF
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_codeshell() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
|
|
struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
cb(tmpq, "tmpq", il);
|
|
cb(tmpk, "tmpk", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
struct ggml_tensor * Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// add the input
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// FF
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_orion() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
// if (model.layers[il].bq) {
|
|
// Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
// cb(Qcur, "Qcur", il);
|
|
// }
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
// if (model.layers[il].bk) {
|
|
// Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
// cb(Kcur, "Kcur", il);
|
|
// }
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
// if (model.layers[il].bv) {
|
|
// Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
// cb(Vcur, "Vcur", il);
|
|
// }
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_internlm2() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_minicpm3() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// TODO: if the model varies, these parameters need to be read from the model
|
|
const int64_t n_embd_base = 256;
|
|
const float scale_embd = 12.0f;
|
|
const float scale_depth = 1.4f;
|
|
const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
|
|
|
|
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
|
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
|
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// scale the input embeddings
|
|
inpL = ggml_scale(ctx0, inpL, scale_embd);
|
|
cb(inpL, "inp_scaled", -1);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
struct ggml_tensor * q = NULL;
|
|
// {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
|
|
q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
|
|
cb(q, "q", il);
|
|
|
|
q = llm_build_norm(ctx0, q, hparams,
|
|
model.layers[il].attn_q_a_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(q, "q", il);
|
|
|
|
// {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
|
|
q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
|
|
cb(q, "q", il);
|
|
|
|
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
|
struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
|
|
ggml_row_size(q->type, hparams.n_embd_head_k),
|
|
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
|
|
0);
|
|
cb(q_nope, "q_nope", il);
|
|
|
|
// and {n_head * n_embd_head_qk_rope, n_tokens}
|
|
struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
|
|
ggml_row_size(q->type, hparams.n_embd_head_k),
|
|
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
|
|
ggml_row_size(q->type, n_embd_head_qk_nope));
|
|
cb(q_pe, "q_pe", il);
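// q is split along the head dimension into a non-rotary part (q_nope) and a rotary part (q_pe);
// only q_pe receives RoPE below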
|
|
|
|
// {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
|
|
struct ggml_tensor * kv_pe_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
|
|
cb(kv_pe_compressed, "kv_pe_compressed", il);
|
|
|
|
// split into {kv_lora_rank, n_tokens}
|
|
struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compressed, kv_lora_rank, n_tokens,
|
|
kv_pe_compressed->nb[1],
|
|
0);
|
|
cb(kv_compressed, "kv_compressed", il);
|
|
|
|
// and {n_embd_head_qk_rope, n_tokens}
|
|
struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compressed, n_embd_head_qk_rope, 1, n_tokens,
|
|
kv_pe_compressed->nb[1],
|
|
kv_pe_compressed->nb[1],
|
|
ggml_row_size(kv_pe_compressed->type, kv_lora_rank));
|
|
cb(k_pe, "k_pe", il);
|
|
|
|
kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
|
|
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
|
|
model.layers[il].attn_kv_a_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(kv_compressed, "kv_compressed", il);
|
|
|
|
// {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
|
|
struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
|
|
cb(kv, "kv", il);
|
|
|
|
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
|
struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
|
|
ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
|
|
ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
|
0);
|
|
cb(k_nope, "k_nope", il);
|
|
|
|
// and {n_head * n_embd_head_v, n_tokens}
|
|
struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
|
|
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
|
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
|
|
ggml_row_size(kv->type, (n_embd_head_qk_nope)));
|
|
cb(v_states, "v_states", il);
|
|
|
|
v_states = ggml_cont(ctx0, v_states);
|
|
cb(v_states, "v_states", il);
|
|
|
|
v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
|
|
ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
|
|
0);
|
|
cb(v_states, "v_states", il);
|
|
|
|
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
|
q_pe = ggml_rope_ext(
|
|
ctx0, q_pe, inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(q_pe, "q_pe", il);
|
|
|
|
// shared RoPE key
|
|
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
|
k_pe = ggml_rope_ext(
|
|
ctx0, k_pe, inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(k_pe, "k_pe", il);
|
|
|
|
struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
|
|
cb(q_states, "q_states", il);
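// the single rotary key k_pe is repeated across all heads and concatenated with the per-head non-rotary keys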
|
|
|
|
struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
|
|
cb(k_states, "k_states", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
// scale_res - scale the hidden states for residual connection
|
|
const float scale_res = scale_depth/sqrtf(float(n_layer));
|
|
cur = ggml_scale(ctx0, cur, scale_res);
|
|
cb(cur, "hidden_scaled", il);
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
// scale the hidden states for residual connection
|
|
cur = ggml_scale(ctx0, cur, scale_res);
|
|
cb(cur, "hidden_scaled_ffn", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head scaling
|
|
const float scale_lmhead = float(n_embd_base)/float(n_embd);
|
|
cur = ggml_scale(ctx0, cur, scale_lmhead);
|
|
cb(cur, "lmhead_scaling", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_gemma() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
|
|
cb(inpL, "inp_scaled", -1);
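// the token embeddings are scaled by sqrt(n_embd), as in the Gemma reference implementation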
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
|
|
cb(Qcur, "Qcur_scaled", il);
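// Q is pre-scaled here, so llm_build_kv below is called with kq_scale = 1.0f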
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
|
|
cb(sa_out, "sa_out", il);
|
|
|
|
cur = llm_build_norm(ctx0, sa_out, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, sa_out);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_gemma2() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
|
|
cb(inpL, "inp_scaled", -1);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
// gemma 2 requires different mask for layers using sliding window (SWA)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
|
|
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
// even-indexed layers (il % 2 == 0) use the sliding-window attention (SWA) mask
|
|
struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
// ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
|
|
switch (model.type) {
|
|
case llm_type::MODEL_2B:
|
|
case llm_type::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
|
|
case llm_type::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
|
|
default: GGML_ABORT("fatal error");
|
|
};
|
|
cb(Qcur, "Qcur_scaled", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].attn_post_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_post_norm", il);
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
|
|
cb(sa_out, "sa_out", il);
|
|
|
|
cur = llm_build_norm(ctx0, sa_out, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].ffn_post_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "ffn_post_norm", -1);
|
|
|
|
cur = ggml_add(ctx0, cur, sa_out);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
// final logit soft-capping
|
|
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
|
|
cur = ggml_tanh(ctx0, cur);
|
|
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
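        // the three ops above implement logits = C * tanh(logits / C) with
        // C = f_final_logit_softcapping, smoothly bounding the final logits to (-C, C)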
|
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
|
|
struct ggml_cgraph * build_starcoder2() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_mamba() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
// {n_embd, n_tokens}
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
struct ggml_tensor * state_copy = build_inp_s_copy();
|
|
struct ggml_tensor * state_mask = build_inp_s_mask();
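        // note (assumed semantics): these inputs manage the recurrent state slots in the cache --
        // state_copy carries indices used to copy/rearrange states between cache cells, and
        // state_mask zeroes the states of sequences that start fresh; both are consumed by
        // llm_build_mamba() below.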
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
cur = llm_build_mamba(ctx0, lctx, ubatch, gf, cur,
|
|
state_copy, state_mask,
|
|
kv_head, n_kv, cb, il);
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// residual
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
// final rmsnorm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_command_r() {
|
|
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
const float f_logit_scale = hparams.f_logit_scale;
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
struct ggml_tensor * ffn_inp = cur;
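            // note: Command-R applies a single pre-norm shared by both branches -- attention and the
            // FFN (fed from ffn_inp) run in parallel on the same normalized input, and the layer
            // output is inpL + attn_out + ffn_out (see the two ggml_add calls at the end of the loop).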
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
if (model.layers[il].attn_q_norm) {
|
|
Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
|
|
ggml_element_size(Qcur) * n_embd_head,
|
|
ggml_element_size(Qcur) * n_embd_head * n_head,
|
|
0);
|
|
cb(Qcur, "Qcur", il);
|
|
Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
|
|
ggml_element_size(Kcur) * n_embd_head,
|
|
ggml_element_size(Kcur) * n_embd_head * n_head_kv,
|
|
0);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
Qcur = llm_build_norm(ctx0, Qcur, hparams,
|
|
model.layers[il].attn_q_norm,
|
|
NULL,
|
|
LLM_NORM, cb, il);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = llm_build_norm(ctx0, Kcur, hparams,
|
|
model.layers[il].attn_k_norm,
|
|
NULL,
|
|
LLM_NORM, cb, il);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * attn_out = cur;
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_ffn(ctx0, lctx, ffn_inp,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
// add together residual + FFN + self-attention
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
cur = ggml_add(ctx0, cur, attn_out);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
if (f_logit_scale) {
|
|
cur = ggml_scale(ctx0, cur, f_logit_scale);
|
|
}
|
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
|
|
}
|
|
|
|
struct ggml_cgraph * build_cohere2() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
const float f_logit_scale = hparams.f_logit_scale;
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
// cohere2 requires different mask for layers using sliding window (SWA)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
|
|
|
|
// sliding window switch pattern
|
|
const int32_t sliding_window_pattern = 4;
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
            // three layers of sliding-window attention (window size 4096) with RoPE,
            // followed by a fourth layer of global attention without positional embeddings
            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
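            // e.g. with sliding_window_pattern == 4: layers with il % 4 in {0, 1, 2} take the
            // is_sliding branch (SWA mask + RoPE), while layers with il % 4 == 3 use the global
            // mask and skip RoPE entirely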
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
struct ggml_tensor * ffn_inp = cur;
|
|
|
|
// self-attention
|
|
{
|
|
// rope freq factors for 128k context
|
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
|
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
if (is_sliding) {
|
|
Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
|
|
beta_fast, beta_slow);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
|
rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
|
|
attn_factor, beta_fast, beta_slow);
|
|
cb(Kcur, "Kcur", il);
|
|
} else {
|
|
// For non-sliding layers, just reshape without applying RoPE
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
|
|
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * attn_out = cur;
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
|
|
NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
|
|
cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
// add together residual + FFN + self-attention
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
cur = ggml_add(ctx0, cur, attn_out);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
if (f_logit_scale) {
|
|
cur = ggml_scale(ctx0, cur, f_logit_scale);
|
|
}
|
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
// ref: https://allenai.org/olmo
|
|
// based on the original build_llama() function, changes:
|
|
// * non-parametric layer norm
|
|
// * clamp qkv
|
|
// * removed bias
|
|
// * removed MoE
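    // note: "non-parametric layer norm" shows up below as llm_build_norm(..., NULL, NULL, LLM_NORM, ...),
    // i.e. a LayerNorm without learned scale/bias, and "clamp qkv" caps the Q/K/V activations to
    // +/- hparams.f_clamp_kqv when that value is set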
|
|
struct ggml_cgraph * build_olmo() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
NULL, NULL,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (hparams.f_clamp_kqv > 0.0f) {
|
|
Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (hparams.f_clamp_kqv > 0.0f) {
|
|
Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (hparams.f_clamp_kqv > 0.0f) {
|
|
Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, nullptr,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
NULL, NULL,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
NULL, NULL,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_olmo2() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
cur = inpL;
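            // note: OLMo 2 normalizes *after* the attention and FFN blocks (attn_post_norm /
            // ffn_post_norm below) instead of before them, so the raw residual stream feeds
            // the attention projections directly here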
|
|
|
|
// self_attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(Qcur, "Qcur_normed", il);
|
|
|
|
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(Kcur, "Kcur_normed", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, Qcur, inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur_rope", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, Kcur, inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur_rope", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].attn_post_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_post_norm", il);
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_ffn(ctx0, lctx, ffn_inp,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
            cur = llm_build_norm(ctx0, cur, hparams,
                    model.layers[il].ffn_post_norm, NULL,
                    LLM_NORM_RMS, cb, il);
            cb(cur, "ffn_post_norm", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
// based on the build_qwen2moe() function, changes:
|
|
// * removed shared experts
|
|
// * removed bias
|
|
// * added q, k norm
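    // note: the added q/k norm is an RMS norm applied to the full Q and K projections before
    // they are reshaped and RoPE'd (attn_q_norm / attn_k_norm below); the FFN is a softmax-gated
    // MoE that routes each token to n_expert_used of the n_expert experts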
|
|
struct ggml_cgraph * build_olmoe() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self_attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(Qcur, "Qcur_normed", il);
|
|
|
|
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(Kcur, "Kcur_normed", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, Qcur, inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur_rope", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, Kcur, inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur_rope", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// MoE branch
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_moe_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_gate_inp,
|
|
model.layers[il].ffn_up_exps,
|
|
model.layers[il].ffn_gate_exps,
|
|
model.layers[il].ffn_down_exps,
|
|
nullptr,
|
|
n_expert, n_expert_used,
|
|
LLM_FFN_SILU, false,
|
|
false, 0.0,
|
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
|
cb, il);
|
|
cb(cur, "ffn_moe_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_openelm() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
const int64_t n_head = hparams.n_head(il);
|
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
|
const int64_t n_head_qkv = 2*n_head_kv + n_head;
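            // note (layout inferred from the view offsets below): the fused wqkv projection packs
            // the heads as [ n_head Q heads | n_head_kv K heads | n_head_kv V heads ], which the
            // three ggml_view_3d calls slice back apart at head offsets 0, n_head and n_head + n_head_kv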
|
|
|
|
cur = inpL;
|
|
struct ggml_tensor * residual = cur;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
|
|
|
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = llm_build_norm(ctx0, Qcur, hparams,
|
|
model.layers[il].attn_q_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = llm_build_norm(ctx0, Kcur, hparams,
|
|
model.layers[il].attn_k_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, Qcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, Kcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
|
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
                Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
                cb(Vcur, "Vcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
residual = ggml_get_rows(ctx0, residual, inp_out_ids);
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_gptneox() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
|
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// ffn
|
|
if (hparams.use_par_res) {
|
|
// attention and ffn are computed in parallel
|
|
// x = x + attn(ln1(x)) + ffn(ln2(x))
|
|
|
|
struct ggml_tensor * attn_out = cur;
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, attn_out);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
} else {
|
|
// attention and ffn are computed sequentially
|
|
// x = x + attn(ln1(x))
|
|
// x = x + ffn(ln2(x))
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_arctic() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
struct ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(ffn_out, "ffn_out", il);
|
|
|
|
// MoE
|
|
cur = llm_build_norm(ctx0, inpSA, hparams,
|
|
model.layers[il].ffn_norm_exps, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm_exps", il);
|
|
|
|
cur = llm_build_moe_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_gate_inp,
|
|
model.layers[il].ffn_up_exps,
|
|
model.layers[il].ffn_gate_exps,
|
|
model.layers[il].ffn_down_exps,
|
|
nullptr,
|
|
n_expert, n_expert_used,
|
|
LLM_FFN_SILU, true,
|
|
false, 0.0,
|
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
|
cb, il);
|
|
cb(cur, "ffn_moe_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_out);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_deepseek() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
                // rope freq factors; build_rope_factors() may return nullptr for models that do not use them
|
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
|
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
if ((uint32_t) il < hparams.n_layer_dense_lead) {
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
} else {
|
|
// MoE branch
|
|
ggml_tensor * moe_out =
|
|
llm_build_moe_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_gate_inp,
|
|
model.layers[il].ffn_up_exps,
|
|
model.layers[il].ffn_gate_exps,
|
|
model.layers[il].ffn_down_exps,
|
|
nullptr,
|
|
n_expert, n_expert_used,
|
|
LLM_FFN_SILU, false,
|
|
false, hparams.expert_weights_scale,
|
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
|
cb, il);
|
|
cb(moe_out, "ffn_moe_out", il);
|
|
|
|
// FFN shared expert
|
|
{
|
|
ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up_shexp, NULL, NULL,
|
|
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
|
model.layers[il].ffn_down_shexp, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(ffn_shexp, "ffn_shexp", il);
|
|
|
|
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_deepseek2() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
bool is_lite = (hparams.n_layer == 27);
|
|
|
|
// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
|
|
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
|
const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
|
|
const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
|
|
const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
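        // rough intuition (see the discussion linked above): ggml_rope_ext applies its own YaRN
        // magnitude scale of roughly 1 + 0.1*log(1/freq_scale), so attn_factor_scaled is chosen to
        // cancel it, and the model-specific scale (mscale, using rope_yarn_log_mul) is instead
        // folded once into kq_scale -- squared because it would otherwise multiply both Q and K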
|
|
|
|
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
|
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
|
const uint32_t kv_lora_rank = hparams.n_lora_kv;
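        // note: with multi-head latent attention (MLA) each head's Q/K is split into a "nope" part
        // (n_embd_head_qk_nope dims, no positional encoding) and a "rope" part (n_embd_head_qk_rope
        // = n_rot dims, RoPE applied); the no-position parts of K together with V are expanded from
        // a shared low-rank latent of width kv_lora_rank via wkv_b below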
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
// {n_embd, n_tokens}
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self_attention
|
|
{
|
|
struct ggml_tensor * q = NULL;
|
|
if (!is_lite) {
|
|
// {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
|
|
q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
|
|
cb(q, "q", il);
|
|
|
|
q = llm_build_norm(ctx0, q, hparams,
|
|
model.layers[il].attn_q_a_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(q, "q", il);
|
|
|
|
// {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
|
|
q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
|
|
cb(q, "q", il);
|
|
} else {
|
|
q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
|
cb(q, "q", il);
|
|
}
|
|
|
|
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
|
struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
|
|
ggml_row_size(q->type, hparams.n_embd_head_k),
|
|
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
|
|
0);
|
|
cb(q_nope, "q_nope", il);
|
|
|
|
// and {n_head * n_embd_head_qk_rope, n_tokens}
|
|
struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
|
|
ggml_row_size(q->type, hparams.n_embd_head_k),
|
|
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
|
|
ggml_row_size(q->type, n_embd_head_qk_nope));
|
|
cb(q_pe, "q_pe", il);
|
|
|
|
                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
                struct ggml_tensor * kv_pe_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
                cb(kv_pe_compressed, "kv_pe_compressed", il);

                // split into {kv_lora_rank, n_tokens}
                struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compressed, kv_lora_rank, n_tokens,
                        kv_pe_compressed->nb[1],
                        0);
                cb(kv_compressed, "kv_compressed", il);

                // and {n_embd_head_qk_rope, n_tokens}
                struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compressed, n_embd_head_qk_rope, 1, n_tokens,
                        kv_pe_compressed->nb[1],
                        kv_pe_compressed->nb[1],
                        ggml_row_size(kv_pe_compressed->type, kv_lora_rank));
                cb(k_pe, "k_pe", il);
|
|
|
|
kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
|
|
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
|
|
model.layers[il].attn_kv_a_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(kv_compressed, "kv_compressed", il);
|
|
|
|
// {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
|
|
struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
|
|
cb(kv, "kv", il);
|
|
|
|
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
|
struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
|
|
ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
|
|
ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
|
0);
|
|
cb(k_nope, "k_nope", il);
|
|
|
|
// and {n_head * n_embd_head_v, n_tokens}
|
|
struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
|
|
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
|
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
|
|
ggml_row_size(kv->type, (n_embd_head_qk_nope)));
|
|
cb(v_states, "v_states", il);
|
|
|
|
v_states = ggml_cont(ctx0, v_states);
|
|
cb(v_states, "v_states", il);
|
|
|
|
v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
|
|
ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
|
|
0);
|
|
cb(v_states, "v_states", il);
|
|
|
|
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
|
q_pe = ggml_rope_ext(
|
|
ctx0, q_pe, inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
|
);
|
|
cb(q_pe, "q_pe", il);
|
|
|
|
// shared RoPE key
|
|
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
|
k_pe = ggml_rope_ext(
|
|
ctx0, k_pe, inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
|
);
|
|
cb(k_pe, "k_pe", il);
|
|
|
|
struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
|
|
cb(q_states, "q_states", il);
|
|
|
|
struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
|
|
cb(k_states, "k_states", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
if ((uint32_t) il < hparams.n_layer_dense_lead) {
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
} else {
|
|
// MoE branch
|
|
ggml_tensor * moe_out =
|
|
llm_build_moe_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_gate_inp,
|
|
model.layers[il].ffn_up_exps,
|
|
model.layers[il].ffn_gate_exps,
|
|
model.layers[il].ffn_down_exps,
|
|
model.layers[il].ffn_exp_probs_b,
|
|
n_expert, n_expert_used,
|
|
LLM_FFN_SILU, hparams.expert_weights_norm,
|
|
true, hparams.expert_weights_scale,
|
|
(enum llama_expert_gating_func_type) hparams.expert_gating_func,
|
|
cb, il);
|
|
cb(moe_out, "ffn_moe_out", il);
|
|
|
|
// FFN shared expert
|
|
{
|
|
ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up_shexp, NULL, NULL,
|
|
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
|
model.layers[il].ffn_down_shexp, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(ffn_shexp, "ffn_shexp", il);
|
|
|
|
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_bitnet() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
if (model.layers[il].wq_scale) {
|
|
Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
|
|
}
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
// B1.K
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
if (model.layers[il].wk_scale) {
|
|
Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
|
|
}
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
// B1.V
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
if (model.layers[il].wv_scale) {
|
|
Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
|
|
}
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
NULL, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].attn_sub_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_sub_norm", il);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
|
if (model.layers[il].wo_scale) {
|
|
cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
|
|
}
|
|
if (model.layers[il].bo) {
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bo);
|
|
}
|
|
cb(cur, "attn_o_out", il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale,
|
|
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
|
|
NULL, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_sub_out", il);
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].ffn_sub_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_sub_norm", il);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur);
|
|
if (model.layers[il].ffn_down_scale) {
|
|
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
|
|
}
|
|
cb(cur, "ffn_down", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
// FIXME: do not use model.tok_embd directly, duplicate as model.output
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_t5_enc() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
GGML_ASSERT(lctx.is_encoding);
|
|
struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm_enc, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
|
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
|
|
|
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
|
cb(kq, "kq", il);
|
|
|
|
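// layers without their own bias tensor fall back to the one from layer 0;
// T5 learns the relative attention bias only in the first layer and shares it across layers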
struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
|
|
struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
|
|
struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
|
|
cb(kq_b, "kq_b", il);
|
|
|
|
kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
|
|
cb(kq, "kq_soft_max_ext", il);
|
|
|
|
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
|
|
cb(v, "v", il);
|
|
|
|
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
|
|
cb(kqv, "kqv", il);
|
|
|
|
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
|
cb(kqv_merged, "kqv_merged", il);
|
|
|
|
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
|
cb(cur, "kqv_merged_cont", il);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur);
|
|
cb(cur, "kqv_out", il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm_enc, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
// T5 uses relu, flan-T5 uses gelu-gated
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up_enc, NULL, NULL,
|
|
model.layers[il].ffn_gate_enc, NULL, NULL,
|
|
model.layers[il].ffn_down_enc, NULL, NULL,
|
|
NULL,
|
|
model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
|
|
model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
|
|
cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
|
|
if (layer_dir != nullptr) {
|
|
cur = ggml_add(ctx0, cur, layer_dir);
|
|
}
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
cb(cur, "result_embd", -1);
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm_enc, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_t5_dec() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
GGML_ASSERT(!lctx.is_encoding);
|
|
GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
|
|
|
|
struct ggml_tensor * embd_enc = llm_build_inp_embd_enc();
|
|
struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
|
|
|
|
struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask();
|
|
struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
|
|
|
struct ggml_tensor * k =
|
|
ggml_view_3d(ctx0, kv_self.k_l[il],
|
|
n_embd_head_k, n_kv, n_head_kv,
|
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
|
0);
|
|
cb(k, "k", il);
|
|
|
|
struct ggml_tensor * v =
|
|
ggml_view_3d(ctx0, kv_self.v_l[il],
|
|
n_kv, n_embd_head_v, n_head_kv,
|
|
ggml_element_size(kv_self.v_l[il])*n_ctx,
|
|
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
|
|
0);
|
|
cb(v, "v", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
|
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
|
|
|
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
|
cb(kq, "kq", il);
|
|
|
|
struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
|
|
struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
|
|
struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
|
|
cb(kq_b, "kq_b", il);
|
|
|
|
kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
|
|
cb(kq, "kq_soft_max_ext", il);
|
|
|
|
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
|
cb(kqv, "kqv", il);
|
|
|
|
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
|
cb(kqv_merged, "kqv_merged", il);
|
|
|
|
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
|
cb(cur, "kqv_merged_cont", il);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
|
cb(cur, "kqv_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, inpSA);
|
|
cb(cur, "cross_inp", il);
|
|
|
|
struct ggml_tensor * inpCA = cur;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].attn_norm_cross, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm_cross", il);
|
|
|
|
// cross-attention
|
|
{
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
|
|
|
|
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
|
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
|
|
|
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
|
cb(kq, "kq", il);
|
|
|
|
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
|
|
cb(kq, "kq_soft_max_ext", il);
|
|
|
|
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
|
|
cb(v, "v", il);
|
|
|
|
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
|
|
cb(kqv, "kqv", il);
|
|
|
|
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
|
cb(kqv_merged, "kqv_merged", il);
|
|
|
|
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
|
|
cb(cur, "kqv_merged_cont", il);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur);
|
|
cb(cur, "kqv_out", il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
// T5 uses relu, flan-T5 uses gelu-gated
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
|
|
model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ,
|
|
cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
|
|
if (layer_dir != nullptr) {
|
|
cur = ggml_add(ctx0, cur, layer_dir);
|
|
}
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
cb(cur, "result_embd", -1);
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_jais() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
|
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
|
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
|
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
|
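// note: the attention scores are scaled by 1/n_embd_head below instead of the
// usual 1/sqrt(n_embd_head)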
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
}
|
|
|
|
// add the input
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// FF
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
}
|
|
|
|
inpL = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(inpL, "l_out", il);
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_chatglm() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
struct ggml_tensor * Qcur = nullptr;
|
|
struct ggml_tensor * Kcur = nullptr;
|
|
struct ggml_tensor * Vcur = nullptr;
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
|
cb(cur, "wqkv", il);
|
|
|
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
|
cb(cur, "bqkv", il);
|
|
|
|
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
//printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur_rope", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur_rope", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, NULL,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
// Add the input
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// FF
|
|
{
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
}
|
|
|
|
inpL = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(inpL, "l_out", il);
|
|
}
|
|
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.output_norm,
|
|
NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_nemotron() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
//GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm,
|
|
model.layers[il].attn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm,
|
|
model.layers[il].ffn_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
|
NULL, NULL, NULL,
|
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
|
NULL,
|
|
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_exaone() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
|
|
// self-attention
|
|
{
|
|
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
|
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
if (model.layers[il].bq) {
|
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
if (model.layers[il].bk) {
|
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
if (model.layers[il].bv) {
|
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
|
cb(Vcur, "Vcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
ggml_cgraph * build_rwkv6() {
|
|
ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// Token shift state dimensions should be 2 * n_embd
|
|
GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
|
|
|
|
const int64_t n_seqs = ubatch.n_seqs;
|
|
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
const int64_t n_tokens = ubatch.n_tokens;
|
|
GGML_ASSERT(n_seqs != 0);
|
|
GGML_ASSERT(ubatch.equal_seqs);
|
|
GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
struct ggml_tensor * state_copy = build_inp_s_copy();
|
|
struct ggml_tensor * state_mask = build_inp_s_mask();
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
const llama_layer * layer = &model.layers[il];
|
|
|
|
// (ab)using the KV cache to store the states
|
|
struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
|
|
gf, kv_self.k_l[il], state_copy, state_mask,
|
|
hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
|
|
struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
|
|
gf, kv_self.v_l[il], state_copy, state_mask,
|
|
hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
|
|
|
|
cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
|
token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);
|
|
|
|
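// the 2*n_embd token shift state per sequence holds one n_embd row for the
// time-mix (attention) branch and one for the channel-mix (FFN) branch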
struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
|
struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
|
|
|
struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il);
|
|
struct ggml_tensor * x_prev = ggml_concat(
|
|
ctx0,
|
|
att_shift,
|
|
ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
|
|
1
|
|
);
|
|
|
|
cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
|
|
ggml_build_forward_expand(gf, cur);
|
|
ggml_build_forward_expand(
|
|
gf,
|
|
ggml_cpy(
|
|
ctx0,
|
|
wkv_states,
|
|
ggml_view_1d(
|
|
ctx0,
|
|
kv_self.v_l[il],
|
|
hparams.n_embd_v_s() * n_seqs,
|
|
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
|
|
)
|
|
)
|
|
);
|
|
|
|
struct ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il);
|
|
x_prev = ggml_concat(
|
|
ctx0,
|
|
ffn_shift,
|
|
ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
|
|
1
|
|
);
|
|
cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
|
|
struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
|
|
|
|
token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
|
|
|
|
ggml_build_forward_expand(
|
|
gf,
|
|
ggml_cpy(
|
|
ctx0,
|
|
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
|
|
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
|
|
)
|
|
);
|
|
|
|
if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
|
|
cur = ggml_scale(ctx0, cur, 0.5F);
|
|
}
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
// ref: https://github.com/facebookresearch/chameleon
|
|
// based on the original build_llama() function, changes:
|
|
// * qk-norm
|
|
// * swin-norm
|
|
// * removed bias
|
|
// * removed MoE
|
|
struct ggml_cgraph * build_chameleon() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
int32_t n_tokens = this->n_tokens;
|
|
|
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
// inp_pos - contains the positions
|
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
// norm
|
|
if (hparams.swin_norm) {
|
|
cur = inpL;
|
|
} else {
|
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "attn_norm", il);
|
|
}
|
|
|
|
// self-attention
|
|
{
|
|
// compute Q and K and RoPE them
|
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
if (model.layers[il].attn_q_norm) {
|
|
Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
|
|
ggml_element_size(Qcur) * n_embd_head,
|
|
ggml_element_size(Qcur) * n_embd_head * n_head,
|
|
0);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Qcur = llm_build_norm(ctx0, Qcur, hparams,
|
|
model.layers[il].attn_q_norm,
|
|
model.layers[il].attn_q_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(Qcur, "Qcur", il);
|
|
}
|
|
|
|
if (model.layers[il].attn_k_norm) {
|
|
Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
|
|
ggml_element_size(Kcur) * n_embd_head,
|
|
ggml_element_size(Kcur) * n_embd_head * n_head_kv,
|
|
0);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
Kcur = llm_build_norm(ctx0, Kcur, hparams,
|
|
model.layers[il].attn_k_norm,
|
|
model.layers[il].attn_k_norm_b,
|
|
LLM_NORM, cb, il);
|
|
cb(Kcur, "Kcur", il);
|
|
}
|
|
|
|
Qcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
Kcur = ggml_rope_ext(
|
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
ext_factor, attn_factor, beta_fast, beta_slow
|
|
);
|
|
cb(Kcur, "Kcur", il);
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
model.layers[il].wo, nullptr,
|
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
|
|
if (hparams.swin_norm) {
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].attn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
}
|
|
}
|
|
|
|
if (il == n_layer - 1) {
|
|
// skip computing output for unused tokens
|
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
n_tokens = n_outputs;
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
}
|
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
// feed-forward network
|
|
if (!hparams.swin_norm) {
|
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
}
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
model.layers[il].ffn_up, NULL, NULL,
|
|
model.layers[il].ffn_gate, NULL, NULL,
|
|
model.layers[il].ffn_down, NULL, NULL,
|
|
NULL,
|
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
if (hparams.swin_norm) {
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.layers[il].ffn_norm, NULL,
|
|
LLM_NORM_RMS, cb, il);
|
|
cb(cur, "ffn_norm", il);
|
|
}
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cb(cur, "ffn_out", il);
|
|
|
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
|
cb(cur, "l_out", il);
|
|
|
|
// input for next layer
|
|
inpL = cur;
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm, NULL,
|
|
LLM_NORM_RMS, cb, -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
cb(cur, "result_output_with_img_logits", -1);
|
|
|
|
// TODO: this suppresses the output of image tokens, which is required to enable text-only outputs.
|
|
// Needs to be removed once image outputs are supported.
|
|
int img_token_end_idx = 8196;
|
|
int img_token_start_idx = 4;
|
|
int num_img_tokens = img_token_end_idx - img_token_start_idx;
|
|
// creates 1d tensor of size num_img_tokens and values -FLT_MAX,
|
|
// which ensures that text token values are always larger than image token values
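// (clamping the new tensor into the degenerate range [-FLT_MAX, -FLT_MAX] is what
// fills its otherwise uninitialized elements with that value)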
|
|
struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens);
|
|
img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX);
|
|
cb(img_logits, "img_logits", -1);
|
|
cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
|
|
cb(cur, "result_output", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
|
|
struct ggml_cgraph * build_wavtokenizer_dec() {
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
|
|
|
struct ggml_tensor * cur;
|
|
struct ggml_tensor * inpL;
|
|
|
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
|
|
|
cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));
|
|
|
|
cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1);
|
|
cur = ggml_add(ctx0, cur, model.conv1d_b);
|
|
|
|
// posnet
|
|
for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) {
|
|
const auto & layer = model.layers[il].posnet;
|
|
|
|
inpL = cur;
|
|
|
|
switch (il) {
|
|
case 0:
|
|
case 1:
|
|
case 3:
|
|
case 4:
|
|
{
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
layer.norm1,
|
|
layer.norm1_b,
|
|
LLM_NORM_GROUP, cb, 0);
|
|
|
|
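// SiLU activation computed explicitly as x * sigmoid(x)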
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
|
|
|
|
cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1);
|
|
cur = ggml_add(ctx0, cur, layer.conv1_b);
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
layer.norm2,
|
|
layer.norm2_b,
|
|
LLM_NORM_GROUP, cb, 0);
|
|
|
|
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
|
|
|
|
cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1);
|
|
cur = ggml_add(ctx0, cur, layer.conv2_b);
|
|
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
} break;
|
|
case 2:
|
|
{
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
layer.attn_norm,
|
|
layer.attn_norm_b,
|
|
LLM_NORM_GROUP, cb, 0);
|
|
|
|
struct ggml_tensor * q;
|
|
struct ggml_tensor * k;
|
|
struct ggml_tensor * v;
|
|
|
|
q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1);
|
|
k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1);
|
|
v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1);
|
|
|
|
q = ggml_add(ctx0, q, layer.attn_q_b);
|
|
k = ggml_add(ctx0, k, layer.attn_k_b);
|
|
v = ggml_add(ctx0, v, layer.attn_v_b);
|
|
|
|
q = ggml_cont(ctx0, ggml_transpose(ctx0, q));
|
|
k = ggml_cont(ctx0, ggml_transpose(ctx0, k));
|
|
|
|
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
|
|
|
kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f);
|
|
|
|
cur = ggml_mul_mat(ctx0, kq, v);
|
|
|
|
cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1);
|
|
cur = ggml_add(ctx0, cur, layer.attn_o_b);
|
|
|
|
cur = ggml_add(ctx0, cur, inpL);
|
|
} break;
|
|
case 5:
|
|
{
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
layer.norm,
|
|
layer.norm_b,
|
|
LLM_NORM_GROUP, cb, 0);
|
|
} break;
|
|
default: GGML_ABORT("unknown posnet layer");
|
|
}
|
|
}
|
|
|
|
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.tok_norm,
|
|
model.tok_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
|
|
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
|
|
|
|
inpL = cur;
|
|
|
|
// convnext
|
|
for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
|
|
const auto & layer = model.layers[il].convnext;
|
|
|
|
cur = inpL;
|
|
|
|
cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1);
|
|
cur = ggml_add(ctx0, cur, layer.dw_b);
|
|
|
|
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
layer.norm,
|
|
layer.norm_b,
|
|
LLM_NORM, cb, -1);
|
|
|
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
|
layer.pw1, layer.pw1_b, NULL,
|
|
NULL, NULL, NULL,
|
|
layer.pw2, layer.pw2_b, NULL,
|
|
NULL,
|
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
|
|
|
cur = ggml_mul(ctx0, cur, layer.gamma);
|
|
|
|
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
|
|
|
|
inpL = ggml_add(ctx0, cur, inpL);
|
|
}
|
|
|
|
cur = inpL;
|
|
|
|
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
model.output_norm,
|
|
model.output_norm_b,
|
|
LLM_NORM, cb, -1);
|
|
|
|
// lm_head
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
|
cur = ggml_add(ctx0, cur, model.output_b);
|
|
cb(cur, "result_embd", -1);
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
|
return gf;
|
|
}
|
|
};
|
|
|
|
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
|
llama_ubatch dummy = {};
|
|
dummy.equal_seqs = true;
|
|
|
|
llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
|
|
|
|
struct llm_build_context llm(lctx, dummy, cb, false);
|
|
|
|
llm.init();
|
|
|
|
struct ggml_cgraph * result = llm.build_defrag(ids);
|
|
|
|
llm.free();
|
|
|
|
return result;
|
|
}
|
|
|
|
static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
|
|
llama_ubatch dummy = {};
|
|
dummy.equal_seqs = true;
|
|
|
|
llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
|
|
|
|
struct llm_build_context llm(lctx, dummy, cb, false);
|
|
|
|
llm.init();
|
|
|
|
struct ggml_cgraph * result = llm.build_k_shift();
|
|
|
|
llm.free();
|
|
|
|
return result;
|
|
}
|
|
|
|
static struct ggml_cgraph * llama_build_graph(
|
|
llama_context & lctx,
|
|
const llama_ubatch & ubatch,
|
|
bool worst_case) {
|
|
const auto & model = lctx.model;
|
|
|
|
// this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
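// e.g. a tensor named "attn_norm" in layer 5 becomes "attn_norm-5", while il == -1
// tensors such as "result_output" keep their plain name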
|
|
llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
|
|
if (il >= 0) {
|
|
ggml_format_name(cur, "%s-%d", name, il);
|
|
} else {
|
|
ggml_set_name(cur, name);
|
|
}
|
|
|
|
if (!lctx.cparams.offload_kqv) {
|
|
if (strcmp(name, "kqv_merged_cont") == 0) {
|
|
// all nodes between the KV store and the attention output are run on the CPU
|
|
ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, lctx.backend_cpu);
|
|
}
|
|
}
|
|
|
|
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
|
|
// FIXME: fix in ggml_backend_sched
|
|
const bool full_offload = lctx.model.n_gpu_layers > (int)lctx.model.hparams.n_layer;
|
|
if (ubatch.n_tokens < 32 || full_offload) {
|
|
if (il != -1 && strcmp(name, "norm") == 0) {
|
|
const auto & dev_layer = lctx.model.dev_layer.at(il);
|
|
for (auto & backend : lctx.backends) {
|
|
if (ggml_backend_get_device(backend.get()) == dev_layer.dev) {
|
|
if (ggml_backend_supports_op(backend.get(), cur)) {
|
|
ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, backend.get());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
struct ggml_cgraph * result = NULL;
|
|
|
|
struct llm_build_context llm(lctx, ubatch, cb, worst_case);
|
|
|
|
llm.init();
|
|
|
|
switch (model.arch) {
|
|
case LLM_ARCH_LLAMA:
|
|
case LLM_ARCH_MINICPM:
|
|
case LLM_ARCH_GRANITE:
|
|
case LLM_ARCH_GRANITE_MOE:
|
|
{
|
|
result = llm.build_llama();
|
|
} break;
|
|
case LLM_ARCH_DECI:
|
|
{
|
|
result = llm.build_deci();
|
|
} break;
|
|
case LLM_ARCH_BAICHUAN:
|
|
{
|
|
result = llm.build_baichuan();
|
|
} break;
|
|
case LLM_ARCH_FALCON:
|
|
{
|
|
result = llm.build_falcon();
|
|
} break;
|
|
case LLM_ARCH_GROK:
|
|
{
|
|
result = llm.build_grok();
|
|
} break;
|
|
case LLM_ARCH_STARCODER:
|
|
{
|
|
result = llm.build_starcoder();
|
|
} break;
|
|
case LLM_ARCH_REFACT:
|
|
{
|
|
result = llm.build_refact();
|
|
} break;
|
|
case LLM_ARCH_BERT:
|
|
case LLM_ARCH_JINA_BERT_V2:
|
|
case LLM_ARCH_NOMIC_BERT:
|
|
{
|
|
result = llm.build_bert();
|
|
} break;
|
|
case LLM_ARCH_BLOOM:
|
|
{
|
|
result = llm.build_bloom();
|
|
} break;
|
|
case LLM_ARCH_MPT:
|
|
{
|
|
result = llm.build_mpt();
|
|
} break;
|
|
case LLM_ARCH_STABLELM:
|
|
{
|
|
result = llm.build_stablelm();
|
|
} break;
|
|
case LLM_ARCH_QWEN:
|
|
{
|
|
result = llm.build_qwen();
|
|
} break;
|
|
case LLM_ARCH_QWEN2:
|
|
{
|
|
result = llm.build_qwen2();
|
|
} break;
|
|
case LLM_ARCH_QWEN2VL:
|
|
{
|
|
lctx.n_pos_per_token = 4;
|
|
result = llm.build_qwen2vl();
|
|
} break;
|
|
case LLM_ARCH_QWEN2MOE:
|
|
{
|
|
result = llm.build_qwen2moe();
|
|
} break;
|
|
case LLM_ARCH_PHI2:
|
|
{
|
|
result = llm.build_phi2();
|
|
} break;
|
|
case LLM_ARCH_PHI3:
|
|
{
|
|
result = llm.build_phi3();
|
|
} break;
|
|
case LLM_ARCH_PLAMO:
|
|
{
|
|
result = llm.build_plamo();
|
|
} break;
|
|
case LLM_ARCH_GPT2:
|
|
{
|
|
result = llm.build_gpt2();
|
|
} break;
|
|
case LLM_ARCH_CODESHELL:
|
|
{
|
|
result = llm.build_codeshell();
|
|
} break;
|
|
case LLM_ARCH_ORION:
|
|
{
|
|
result = llm.build_orion();
|
|
} break;
|
|
case LLM_ARCH_INTERNLM2:
|
|
{
|
|
result = llm.build_internlm2();
|
|
} break;
|
|
case LLM_ARCH_MINICPM3:
|
|
{
|
|
result = llm.build_minicpm3();
|
|
} break;
|
|
case LLM_ARCH_GEMMA:
|
|
{
|
|
result = llm.build_gemma();
|
|
} break;
|
|
case LLM_ARCH_GEMMA2:
|
|
{
|
|
result = llm.build_gemma2();
|
|
} break;
|
|
case LLM_ARCH_STARCODER2:
|
|
{
|
|
result = llm.build_starcoder2();
|
|
} break;
|
|
case LLM_ARCH_MAMBA:
|
|
{
|
|
result = llm.build_mamba();
|
|
} break;
|
|
case LLM_ARCH_XVERSE:
|
|
{
|
|
result = llm.build_xverse();
|
|
} break;
|
|
case LLM_ARCH_COMMAND_R:
|
|
{
|
|
result = llm.build_command_r();
|
|
} break;
|
|
case LLM_ARCH_COHERE2:
|
|
{
|
|
result = llm.build_cohere2();
|
|
} break;
|
|
case LLM_ARCH_DBRX:
|
|
{
|
|
result = llm.build_dbrx();
|
|
} break;
|
|
case LLM_ARCH_OLMO:
|
|
{
|
|
result = llm.build_olmo();
|
|
} break;
|
|
case LLM_ARCH_OLMO2:
|
|
{
|
|
result = llm.build_olmo2();
|
|
} break;
|
|
case LLM_ARCH_OLMOE:
|
|
{
|
|
result = llm.build_olmoe();
|
|
} break;
|
|
case LLM_ARCH_OPENELM:
|
|
{
|
|
result = llm.build_openelm();
|
|
} break;
|
|
case LLM_ARCH_GPTNEOX:
|
|
{
|
|
result = llm.build_gptneox();
|
|
} break;
|
|
case LLM_ARCH_ARCTIC:
|
|
{
|
|
result = llm.build_arctic();
|
|
} break;
|
|
case LLM_ARCH_DEEPSEEK:
|
|
{
|
|
result = llm.build_deepseek();
|
|
} break;
|
|
case LLM_ARCH_DEEPSEEK2:
|
|
{
|
|
result = llm.build_deepseek2();
|
|
} break;
|
|
case LLM_ARCH_CHATGLM:
|
|
{
|
|
result = llm.build_chatglm();
|
|
} break;
|
|
case LLM_ARCH_BITNET:
|
|
{
|
|
result = llm.build_bitnet();
|
|
} break;
|
|
case LLM_ARCH_T5:
|
|
{
|
|
if (lctx.is_encoding) {
|
|
result = llm.build_t5_enc();
|
|
} else {
|
|
result = llm.build_t5_dec();
|
|
}
|
|
} break;
|
|
case LLM_ARCH_T5ENCODER:
|
|
{
|
|
result = llm.build_t5_enc();
|
|
} break;
|
|
case LLM_ARCH_JAIS:
|
|
{
|
|
result = llm.build_jais();
|
|
} break;
|
|
case LLM_ARCH_NEMOTRON:
|
|
{
|
|
result = llm.build_nemotron();
|
|
} break;
|
|
case LLM_ARCH_EXAONE:
|
|
{
|
|
result = llm.build_exaone();
|
|
} break;
|
|
case LLM_ARCH_RWKV6:
|
|
{
|
|
result = llm.build_rwkv6();
|
|
} break;
|
|
case LLM_ARCH_CHAMELEON:
|
|
{
|
|
result = llm.build_chameleon();
|
|
} break;
|
|
case LLM_ARCH_WAVTOKENIZER_DEC:
|
|
{
|
|
result = llm.build_wavtokenizer_dec();
|
|
} break;
|
|
default:
|
|
GGML_ABORT("fatal error");
|
|
}
|
|
|
|
// add on pooling layer
|
|
if (lctx.cparams.embeddings) {
|
|
result = llm.append_pooling(result);
|
|
}
|
|
|
|
llm.free();
|
|
|
|
return result;
|
|
}
|
|
|
|
// returns the result of ggml_backend_sched_graph_compute_async execution
|
|
static enum ggml_status llama_graph_compute(
|
|
llama_context & lctx,
|
|
ggml_cgraph * gf,
|
|
int n_threads,
|
|
ggml_threadpool * threadpool) {
|
|
if (lctx.backend_cpu != nullptr) {
|
|
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(lctx.backend_cpu));
|
|
auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
|
|
set_threadpool_fn(lctx.backend_cpu, threadpool);
|
|
}
|
|
|
|
// set the number of threads for all the backends
|
|
for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
|
|
set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
|
|
}
|
|
|
|
auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
|
|
if (status != GGML_STATUS_SUCCESS) {
|
|
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
|
|
}
|
|
|
|
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
|
|
|
|
return status;
|
|
}
|
|
|
|
// decode a batch of tokens by evaluating the transformer
|
|
// in case of unsuccessful decoding (error or warning),
|
|
// the kv_cache state will be returned to its original state
|
|
// (for non-recurrent models) or cleaned (for recurrent models)
|
|
//
|
|
// - lctx: llama context
|
|
// - batch: batch to evaluate
|
|
//
|
|
// return 0 on success
|
|
// return positive int on warning
|
|
// return negative int on error
|
|
//
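// minimal illustrative sketch (not part of this file) of how the public llama_decode()
// wrapper around this function is typically driven, assuming an initialized
// llama_context * ctx and a populated llama_batch batch:
//
//   const int ret = llama_decode(ctx, batch);
//   if (ret < 0) { /* error: abort, the kv_cache state was restored/cleared */ }
//   if (ret > 0) { /* warning, e.g. no KV slot found: retry with a smaller batch */ }
//   if (ret == 0) { /* success: read outputs, e.g. llama_get_logits_ith(ctx, i) */ }
//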
|
|
static int llama_decode_impl(
|
|
llama_context & lctx,
|
|
llama_batch inp_batch) {
|
|
|
|
lctx.is_encoding = false;
|
|
|
|
if (inp_batch.n_tokens == 0) {
|
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
return -1;
|
|
}
|
|
|
|
// temporary allocate memory for the input batch if needed
|
|
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
|
|
|
|
const llama_batch & batch = batch_allocr.batch;
|
|
const uint32_t n_tokens_all = batch.n_tokens;
|
|
|
|
const auto & model = lctx.model;
|
|
const auto & hparams = model.hparams;
|
|
const auto & cparams = lctx.cparams;
|
|
|
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
|
|
if (batch.token) {
|
|
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
|
if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
|
|
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
|
return -1;
|
|
}
|
|
}
|
|
}
|
|
|
|
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
|
|
|
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
|
|
|
|
if (lctx.t_compute_start_us == 0) {
|
|
lctx.t_compute_start_us = ggml_time_us();
|
|
}
|
|
lctx.n_queued_tokens += n_tokens_all;
|
|
|
|
auto & kv_self = lctx.kv_self;
|
|
llama_kv_slot_restorer kv_slot_restorer(kv_self);
|
|
|
|
const int64_t n_embd = hparams.n_embd;
|
|
const int64_t n_vocab = hparams.n_vocab;
|
|
|
|
uint32_t n_outputs = 0;
|
|
uint32_t n_outputs_prev = 0;
|
|
|
|
const auto n_ubatch = cparams.n_ubatch;
|
|
|
|
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
|
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
|
|
|
lctx.embd_seq.clear();
|
|
|
|
// count outputs
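// e.g. a 5-token batch with batch.logits = {0, 0, 0, 0, 1} yields n_outputs = 1,
// while pooled embeddings or logits_all force n_outputs = n_tokens_all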
|
|
if (batch.logits && !embd_pooled) {
|
|
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
|
n_outputs += batch.logits[i] != 0;
|
|
}
|
|
} else if (lctx.logits_all || embd_pooled) {
|
|
n_outputs = n_tokens_all;
|
|
} else {
|
|
// keep last output only
|
|
n_outputs = 1;
|
|
}
|
|
|
|
lctx.sbatch.from_batch(batch, n_embd,
|
|
/* simple_split */ !kv_self.recurrent,
|
|
/* logits_all */ n_outputs == n_tokens_all);
|
|
|
|
// reserve output buffer
|
|
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
|
|
return -2;
|
|
}
|
|
|
|
while (lctx.sbatch.n_tokens > 0) {
|
|
llama_ubatch ubatch;
|
|
if (kv_self.recurrent) {
|
|
if (embd_pooled) {
|
|
// Pooled embeddings cannot be split across ubatches (yet)
|
|
ubatch = lctx.sbatch.split_seq(n_ubatch);
|
|
} else {
|
|
// recurrent model architectures are easier to implement
|
|
// with equal-length sequences
|
|
ubatch = lctx.sbatch.split_equal(n_ubatch);
|
|
}
|
|
} else {
|
|
ubatch = lctx.sbatch.split_simple(n_ubatch);
|
|
}
|
|
const uint32_t n_tokens = ubatch.n_tokens;
|
|
|
|
// count the outputs in this u_batch
|
|
{
|
|
int32_t n_outputs_new = 0;
|
|
|
|
if (n_outputs == n_tokens_all) {
|
|
n_outputs_new = n_tokens;
|
|
} else {
|
|
GGML_ASSERT(ubatch.output);
|
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
|
}
|
|
}
|
|
|
|
// needs to happen before the graph is built
|
|
lctx.n_outputs = n_outputs_new;
|
|
}
|
|
|
|
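// single-token ubatches (typical for generation) use n_threads, larger ubatches
// (typical for prompt processing) use n_threads_batch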
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
|
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
|
|
|
|
GGML_ASSERT(n_threads > 0);
|
|
|
|
// non-causal masks do not use the KV cache
|
|
if (hparams.causal_attn) {
|
|
llama_kv_cache_update(&lctx);
|
|
|
|
// if we have enough unused cells before the current head ->
|
|
// better to start searching from the beginning of the cache, hoping to fill it
|
|
if (kv_self.head > kv_self.used + 2*n_tokens) {
|
|
kv_self.head = 0;
|
|
}
|
|
|
|
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
|
if (!slot) {
|
|
return 1;
|
|
}
|
|
kv_slot_restorer.save(slot);
|
|
|
|
if (!kv_self.recurrent) {
|
|
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
|
// after enough generations, the benefit from this heuristic disappears
|
|
// if we start defragmenting the cache, the benefit from this will be more important
|
|
const uint32_t pad = llama_kv_cache_get_padding(cparams);
|
|
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
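// e.g. with pad = 32 and llama_kv_cache_cell_max() = 45 this gives GGML_PAD(45, 32) = 64,
// so only the first 64 cells are attended instead of the whole cache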
|
|
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
|
}
|
|
}
|
|
|
|
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
|
|
|
ggml_backend_sched_reset(lctx.sched.get());
|
|
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
|
|
|
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
|
|
|
|
// the output is always the last tensor in the graph
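// for the graphs built above the last node is the "result_output" logits tensor and the
// node before it is normally the "result_norm" embeddings tensor (see the checks below)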
|
|
struct ggml_tensor * res = ggml_graph_node(gf, -1);
|
|
struct ggml_tensor * embd = ggml_graph_node(gf, -2);
|
|
|
|
if (lctx.n_outputs == 0) {
|
|
// no output
|
|
res = nullptr;
|
|
embd = nullptr;
|
|
} else if (cparams.embeddings) {
|
|
res = nullptr; // do not extract logits for embedding case
|
|
embd = nullptr;
|
|
for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
|
|
if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
|
|
embd = ggml_graph_node(gf, i);
|
|
break;
|
|
}
|
|
}
|
|
GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
|
|
} else {
|
|
embd = nullptr; // do not extract embeddings when not needed
|
|
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
|
|
}
|
|
|
|
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
|
|
|
ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
|
|
|
|
llama_set_inputs(lctx, ubatch);
|
|
|
|
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
|
|
if (compute_status != GGML_STATUS_SUCCESS) {
|
|
kv_slot_restorer.restore(kv_self);
|
|
switch (compute_status) {
|
|
case GGML_STATUS_ABORTED:
|
|
return 2;
|
|
case GGML_STATUS_ALLOC_FAILED:
|
|
return -2;
|
|
case GGML_STATUS_FAILED:
|
|
default:
|
|
return -3;
|
|
}
|
|
}
|
|
|
|
// update the kv ring buffer
|
|
{
|
|
kv_self.head += n_tokens;
|
|
|
|
// Ensure kv cache head points to a valid index.
|
|
if (kv_self.head >= kv_self.size) {
|
|
kv_self.head = 0;
|
|
}
|
|
}
|
|
|
|
// plot the computation graph in dot format (for debugging purposes)
|
|
//if (n_past%100 == 0) {
|
|
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
|
//}
|
|
|
|
// extract logits
|
|
if (res) {
|
|
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), res);
|
|
GGML_ASSERT(backend_res != nullptr);
|
|
GGML_ASSERT(lctx.logits != nullptr);
|
|
|
|
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
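            // lctx.logits is laid out as n_outputs rows of n_vocab floats; outputs from previously
            // processed ubatches occupy the first n_outputs_prev rows, so this ubatch writes its rows
            // immediately after them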
const int32_t n_outputs_new = lctx.n_outputs;
|
|
|
|
if (n_outputs_new) {
|
|
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
|
|
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
|
|
ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
|
|
}
|
|
}
|
|
|
|
// extract embeddings
|
|
if (embd) {
|
|
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
|
|
GGML_ASSERT(backend_embd != nullptr);
|
|
|
|
switch (cparams.pooling_type) {
|
|
case LLAMA_POOLING_TYPE_NONE:
|
|
{
|
|
// extract token embeddings
|
|
GGML_ASSERT(lctx.embd != nullptr);
|
|
float * embd_out = lctx.embd + n_outputs_prev*n_embd;
|
|
const int32_t n_outputs_new = lctx.n_outputs;
|
|
|
|
if (n_outputs_new) {
|
|
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
|
|
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
|
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
|
|
}
|
|
} break;
|
|
case LLAMA_POOLING_TYPE_MEAN:
|
|
case LLAMA_POOLING_TYPE_CLS:
|
|
case LLAMA_POOLING_TYPE_LAST:
|
|
{
|
|
// extract sequence embeddings (cleared before processing each batch)
|
|
auto & embd_seq_out = lctx.embd_seq;
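                        // the pooled result tensor holds one n_embd-sized row per sequence, indexed by
                        // seq_id, which is why the read offset below is n_embd*seq_id floats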
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
|
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
continue;
|
|
}
|
|
embd_seq_out[seq_id].resize(n_embd);
|
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
|
}
|
|
} break;
|
|
case LLAMA_POOLING_TYPE_RANK:
|
|
{
|
|
// extract the rerank score - a single float per sequence
|
|
auto & embd_seq_out = lctx.embd_seq;
|
|
|
|
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
|
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
continue;
|
|
}
|
|
embd_seq_out[seq_id].resize(1);
|
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
|
}
|
|
} break;
|
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
|
{
|
|
GGML_ABORT("unknown pooling type");
|
|
}
|
|
}
|
|
}
|
|
n_outputs_prev += lctx.n_outputs;
|
|
}
|
|
|
|
// set output mappings
|
|
{
|
|
bool sorted_output = true;
|
|
|
|
GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
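        // sbatch.out_ids records, for each output row, the index of the corresponding token in the
        // original batch (the sbatch may have reordered tokens); output_ids inverts that mapping so
        // that llama_get_logits_ith can translate a batch index back into a row of the output buffers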
for (size_t i = 0; i < n_outputs; ++i) {
|
|
size_t out_id = lctx.sbatch.out_ids[i];
|
|
lctx.output_ids[out_id] = i;
|
|
if (out_id != i) {
|
|
sorted_output = false;
|
|
}
|
|
}
|
|
|
|
if (sorted_output) {
|
|
lctx.sbatch.out_ids.clear();
|
|
}
|
|
}
|
|
|
|
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
|
lctx.n_outputs = n_outputs;
|
|
|
|
// wait for the computation to finish (automatically done when obtaining the model output)
|
|
//llama_synchronize(&lctx);
|
|
|
|
// decide if we need to defrag the kv cache
|
|
if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
|
|
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
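        // (illustrative example: with kv_self.n = 256 and kv_self.used = 192, fragmentation = 1 - 192/256 = 0.25,
        //  which queues a defrag for any defrag_thold below 0.25)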
// queue defragmentation for next llama_kv_cache_update
|
|
if (fragmentation > cparams.defrag_thold) {
|
|
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
|
|
|
|
llama_kv_cache_defrag(kv_self);
|
|
}
|
|
}
|
|
|
|
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
|
// overlap with device computation.
|
|
ggml_backend_sched_reset(lctx.sched.get());
|
|
|
|
return 0;
|
|
}
|
|
|
|
// encode a batch of tokens by evaluating the encoder part of the transformer
|
|
//
|
|
// - lctx: llama context
|
|
// - batch: batch to evaluate
|
|
//
|
|
// return 0 on success
|
|
// return positive int on warning
|
|
// return negative int on error
|
|
//
|
|
static int llama_encode_impl(
|
|
llama_context & lctx,
|
|
llama_batch inp_batch) {
|
|
|
|
lctx.is_encoding = true;
|
|
|
|
if (inp_batch.n_tokens == 0) {
|
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
return -1;
|
|
}
|
|
|
|
// temporary allocate memory for the input batch if needed
|
|
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
|
|
|
|
const llama_batch & batch = batch_allocr.batch;
|
|
const uint32_t n_tokens = batch.n_tokens;
|
|
|
|
const auto & model = lctx.model;
|
|
const auto & hparams = model.hparams;
|
|
const auto & cparams = lctx.cparams;
|
|
|
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
|
|
if (batch.token) {
|
|
for (uint32_t i = 0; i < n_tokens; ++i) {
|
|
if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
|
|
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
|
return -1;
|
|
}
|
|
}
|
|
}
|
|
|
|
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
|
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
|
|
|
|
if (lctx.t_compute_start_us == 0) {
|
|
lctx.t_compute_start_us = ggml_time_us();
|
|
}
|
|
|
|
lctx.n_queued_tokens += n_tokens;
|
|
|
|
const int64_t n_embd = hparams.n_embd;
|
|
|
|
lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
|
|
|
const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
|
|
|
|
// reserve output buffer
|
|
if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
|
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
|
return -2;
|
|
    }
|
|
|
|
for (uint32_t i = 0; i < n_tokens; ++i) {
|
|
lctx.output_ids[i] = i;
|
|
}
|
|
|
|
lctx.inp_embd_enc = NULL;
|
|
lctx.n_outputs = n_tokens;
|
|
|
|
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
|
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
|
|
|
|
GGML_ASSERT(n_threads > 0);
|
|
|
|
ggml_backend_sched_reset(lctx.sched.get());
|
|
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
|
|
|
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
|
|
|
|
// the output embeddings after the final encoder normalization
|
|
struct ggml_tensor * embd = nullptr;
|
|
|
|
// there are two cases here
|
|
if (llama_model_has_decoder(&lctx.model)) {
|
|
// first case is an encoder-decoder T5 model where embeddings are passed to decoder
|
|
embd = ggml_graph_node(gf, -1);
|
|
        GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_norm tensor");
|
|
} else {
|
|
// second case is an encoder-only T5 model
|
|
if (cparams.embeddings) {
|
|
// only output embeddings if required
|
|
embd = ggml_graph_node(gf, -1);
|
|
if (strcmp(embd->name, "result_embd_pooled") != 0) {
|
|
embd = ggml_graph_node(gf, -2);
|
|
}
|
|
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
|
|
}
|
|
}
|
|
|
|
ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
|
|
|
|
llama_set_inputs(lctx, ubatch);
|
|
|
|
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
|
|
switch (compute_status) {
|
|
case GGML_STATUS_SUCCESS:
|
|
break;
|
|
case GGML_STATUS_ABORTED:
|
|
return 2;
|
|
case GGML_STATUS_ALLOC_FAILED:
|
|
return -2;
|
|
case GGML_STATUS_FAILED:
|
|
default:
|
|
return -3;
|
|
}
|
|
|
|
// extract embeddings
|
|
if (embd) {
|
|
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
|
|
GGML_ASSERT(backend_embd != nullptr);
|
|
|
|
if (llama_model_has_decoder(&lctx.model)) {
|
|
lctx.embd_enc.resize(n_tokens*n_embd);
|
|
float * embd_out = lctx.embd_enc.data();
|
|
|
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
|
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
|
|
|
// remember the sequence ids used during the encoding - needed for cross attention later
|
|
lctx.seq_ids_enc.resize(n_tokens);
|
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
|
|
llama_seq_id seq_id = ubatch.seq_id[i][s];
|
|
lctx.seq_ids_enc[i].insert(seq_id);
|
|
}
|
|
}
|
|
} else {
|
|
GGML_ASSERT(lctx.embd != nullptr);
|
|
|
|
switch (cparams.pooling_type) {
|
|
case LLAMA_POOLING_TYPE_NONE:
|
|
{
|
|
// extract token embeddings
|
|
GGML_ASSERT(lctx.embd != nullptr);
|
|
float * embd_out = lctx.embd;
|
|
|
|
GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
|
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
|
} break;
|
|
case LLAMA_POOLING_TYPE_MEAN:
|
|
case LLAMA_POOLING_TYPE_CLS:
|
|
case LLAMA_POOLING_TYPE_LAST:
|
|
{
|
|
// extract sequence embeddings
|
|
auto & embd_seq_out = lctx.embd_seq;
|
|
embd_seq_out.clear();
|
|
|
|
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
|
|
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
continue;
|
|
}
|
|
embd_seq_out[seq_id].resize(n_embd);
|
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
|
}
|
|
} break;
|
|
case LLAMA_POOLING_TYPE_RANK:
|
|
{
|
|
                        // TODO: this likely should mirror the logic in llama_decode_impl, but it is better to
                        // wait for an encoder model that requires this pooling type in order to test it
|
|
// https://github.com/ggerganov/llama.cpp/pull/9510
|
|
GGML_ABORT("RANK pooling not implemented yet");
|
|
}
|
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
|
{
|
|
GGML_ABORT("unknown pooling type");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
|
// overlap with device computation.
|
|
ggml_backend_sched_reset(lctx.sched.get());
|
|
|
|
return 0;
|
|
}
|
|
|
|
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
|
|
static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
|
|
auto & kv_self = lctx.kv_self;
|
|
|
|
const auto & hparams = lctx.model.hparams;
|
|
|
|
const uint32_t n_layer = hparams.n_layer;
|
|
|
|
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
|
|
const uint32_t n_used = kv_self.used;
|
|
|
|
assert(n_used <= n_kv);
|
|
|
|
//const int64_t t_start = ggml_time_us();
|
|
|
|
// number of cells moved
|
|
uint32_t n_moves = 0;
|
|
|
|
// each move requires 6*n_layer tensors (see build_defrag)
|
|
// - source view, destination view, copy operation
|
|
// - x2 for keys and values
|
|
//const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer);
|
|
// TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
|
|
const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer);
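    // (illustrative example with a hypothetical node budget of 8192 and n_layer = 32:
    //  max_moves = (8192 - 64)/(6*32) = 42, i.e. at most 42 independent block moves per defrag graph)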
// determine which KV cells to move where
|
|
//
|
|
// cell i moves to ids[i]
|
|
//
|
|
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
|
|
//
|
|
std::vector<uint32_t> ids(n_kv, n_kv);
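    // (illustrative example: with n_kv = 8, used cells {0,1,4,5,6,7} and a 2-cell hole at positions 2-3,
    //  cells 6 and 7 are mapped into the hole, giving ids = {0,1,8,8,4,5,2,3}, where 8 == n_kv means "not moved")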
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
|
const auto & cell0 = kv_self.cells[i0];
|
|
|
|
if (!cell0.is_empty()) {
|
|
ids[i0] = i0;
|
|
|
|
continue;
|
|
}
|
|
|
|
// found a hole - fill it with data from the end of the cache
|
|
|
|
uint32_t nh = 1;
|
|
|
|
// determine the size of the hole
|
|
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
|
|
nh++;
|
|
}
|
|
|
|
uint32_t nf = 0;
|
|
uint32_t is = n_kv - 1;
|
|
|
|
// starting from the end, find nh non-empty cells
|
|
for (; is > i0; --is) {
|
|
const auto & cell1 = kv_self.cells[is];
|
|
|
|
if (cell1.is_empty() || ids[is] != n_kv) {
|
|
continue;
|
|
}
|
|
|
|
// non-empty cell which is not yet moved
|
|
nf++;
|
|
|
|
if (nf == nh) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
// this can only happen if `n_used` is not accurate, which would be a bug
|
|
GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
|
|
|
|
nf = 0;
|
|
|
|
uint32_t i1 = is;
|
|
|
|
        // are we moving a contiguous block of memory?
|
|
bool cont = false;
|
|
|
|
// should we stop searching for the next move?
|
|
bool stop = false;
|
|
|
|
// go back and move the nf cells to the hole
|
|
for (; i1 < n_kv; ++i1) {
|
|
auto & cell1 = kv_self.cells[i1];
|
|
|
|
if (cell1.is_empty() || ids[i1] != n_kv) {
|
|
if (n_moves == max_moves) {
|
|
stop = true;
|
|
break;
|
|
}
|
|
|
|
cont = false;
|
|
continue;
|
|
}
|
|
|
|
// this cell goes to (i0 + nf)
|
|
ids[i1] = i0 + nf;
|
|
|
|
// move the cell meta data
|
|
kv_self.cells[i0 + nf] = cell1;
|
|
|
|
// clear the old cell and move the head there
|
|
cell1 = llama_kv_cell();
|
|
kv_self.head = n_used;
|
|
|
|
if (!cont) {
|
|
n_moves++;
|
|
cont = true;
|
|
}
|
|
|
|
nf++;
|
|
|
|
if (nf == nh) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (stop || n_moves == max_moves) {
|
|
break;
|
|
}
|
|
|
|
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
|
|
|
i0 += nh - 1;
|
|
}
|
|
|
|
if (n_moves == 0) {
|
|
return;
|
|
}
|
|
|
|
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
|
|
|
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
|
|
|
|
#if 0
|
|
// CPU defrag
|
|
//
|
|
// TODO: optimizations are possible:
|
|
// - multiple threads
|
|
// - avoid copying to the host memory when already there
|
|
//
|
|
// likely not worth the effort, as we have ggml_graph based defrag
|
|
//
|
|
|
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
|
|
|
const uint32_t kv_size = kv_self.size;
|
|
|
|
std::vector<uint8_t> buf_k;
|
|
std::vector<uint8_t> buf_v;
|
|
|
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
|
|
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_size);
|
|
|
|
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
|
const size_t v_size = ggml_row_size (kv_self.v_l[il]->type, n_embd_v_gqa*kv_size);
|
|
|
|
buf_k.resize(k_size);
|
|
buf_v.resize(v_size);
|
|
|
|
ggml_backend_tensor_get(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
|
|
ggml_backend_tensor_get(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
|
|
|
|
// batch move [i, i+nm) to [id, id+nm)
|
|
// note: cells can move only to a lower index
|
|
for (uint32_t i = 0; i < n_kv; ++i) {
|
|
const uint32_t id = ids[i];
|
|
|
|
if (i == id || id == n_kv) {
|
|
continue;
|
|
}
|
|
|
|
uint32_t nm = 1;
|
|
|
|
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
|
nm++;
|
|
}
|
|
|
|
// move keys
|
|
{
|
|
const int64_t os = i*k_size_row;
|
|
const int64_t od = id*k_size_row;
|
|
|
|
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
|
}
|
|
|
|
// move values (note: they are transposed)
|
|
{
|
|
const int64_t os = i;
|
|
const int64_t od = id;
|
|
|
|
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
|
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
|
}
|
|
}
|
|
|
|
i += nm - 1;
|
|
}
|
|
|
|
ggml_backend_tensor_set(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
|
|
ggml_backend_tensor_set(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
|
|
}
|
|
#else
|
|
// ggml_graph defrag
|
|
|
|
ggml_backend_sched_reset(lctx.sched.get());
|
|
|
|
ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
|
|
|
|
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
|
#endif
|
|
|
|
//const int64_t t_end = ggml_time_us();
|
|
|
|
//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
|
|
}
|
|
|
|
static void llama_kv_cache_update_impl(struct llama_context & lctx) {
|
|
bool need_reserve = false;
|
|
|
|
if (lctx.kv_self.has_shift) {
|
|
if (!llama_kv_cache_can_shift(&lctx)) {
|
|
GGML_ABORT("The current context does not support K-shift");
|
|
}
|
|
|
|
// apply K-shift if needed
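        // (a K-shift re-applies RoPE to the cached K tensors using the per-cell position deltas accumulated
        //  by llama_kv_cache_seq_add/seq_div, so the cache stays consistent with the shifted token positions)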
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
|
ggml_backend_sched_reset(lctx.sched.get());
|
|
|
|
ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
|
|
|
|
ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
|
|
|
|
llama_set_k_shift(lctx);
|
|
|
|
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
|
|
|
need_reserve = true;
|
|
}
|
|
|
|
{
|
|
auto & kv_self = lctx.kv_self;
|
|
|
|
kv_self.has_shift = false;
|
|
|
|
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
|
kv_self.cells[i].delta = 0;
|
|
}
|
|
}
|
|
}
|
|
|
|
// defragment the KV cache if needed
|
|
if (lctx.kv_self.do_defrag) {
|
|
llama_kv_cache_defrag_impl(lctx);
|
|
|
|
need_reserve = true;
|
|
|
|
lctx.kv_self.do_defrag = false;
|
|
}
|
|
|
|
// reserve a worst case graph again
|
|
if (need_reserve) {
|
|
// TODO: extract to a function
|
|
// build worst-case graph
|
|
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
|
uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
|
|
        llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between the token and embedding input graphs
|
|
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
|
|
|
// initialize scheduler with the worst-case graph
|
|
ggml_backend_sched_reset(lctx.sched.get());
|
|
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
|
|
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
}
|
|
}
|
|
}
|
|
|
|
int32_t llama_lora_adapter_set(
|
|
struct llama_context * ctx,
|
|
struct llama_lora_adapter * adapter,
|
|
float scale) {
|
|
ctx->lora_adapters[adapter] = scale;
|
|
return 0;
|
|
}
|
|
|
|
int32_t llama_lora_adapter_remove(
|
|
struct llama_context * ctx,
|
|
struct llama_lora_adapter * adapter) {
|
|
auto pos = ctx->lora_adapters.find(adapter);
|
|
if (pos != ctx->lora_adapters.end()) {
|
|
ctx->lora_adapters.erase(pos);
|
|
return 0;
|
|
}
|
|
|
|
return -1;
|
|
}
|
|
|
|
void llama_lora_adapter_clear(struct llama_context * ctx) {
|
|
ctx->lora_adapters.clear();
|
|
}
|
|
|
|
// TODO: tmp
|
|
int32_t llama_control_vector_apply(
|
|
struct llama_context * lctx,
|
|
const float * data,
|
|
size_t len,
|
|
int32_t n_embd,
|
|
int32_t il_start,
|
|
int32_t il_end) {
|
|
return llama_control_vector_apply(lctx->cvec, lctx->model, data, len, n_embd, il_start, il_end);
|
|
}
|
|
|
|
//
|
|
// interface implementation
|
|
//
|
|
|
|
struct llama_context_params llama_context_default_params() {
|
|
struct llama_context_params result = {
|
|
/*.n_ctx =*/ 512,
|
|
/*.n_batch =*/ 2048,
|
|
/*.n_ubatch =*/ 512,
|
|
/*.n_seq_max =*/ 1,
|
|
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
|
|
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
|
|
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
|
|
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
|
|
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
|
|
/*.rope_freq_base =*/ 0.0f,
|
|
/*.rope_freq_scale =*/ 0.0f,
|
|
/*.yarn_ext_factor =*/ -1.0f,
|
|
/*.yarn_attn_factor =*/ 1.0f,
|
|
/*.yarn_beta_fast =*/ 32.0f,
|
|
/*.yarn_beta_slow =*/ 1.0f,
|
|
/*.yarn_orig_ctx =*/ 0,
|
|
/*.defrag_thold =*/ -1.0f,
|
|
/*.cb_eval =*/ nullptr,
|
|
/*.cb_eval_user_data =*/ nullptr,
|
|
/*.type_k =*/ GGML_TYPE_F16,
|
|
/*.type_v =*/ GGML_TYPE_F16,
|
|
/*.logits_all =*/ false,
|
|
/*.embeddings =*/ false,
|
|
/*.offload_kqv =*/ true,
|
|
/*.flash_attn =*/ false,
|
|
/*.no_perf =*/ true,
|
|
/*.abort_callback =*/ nullptr,
|
|
/*.abort_callback_data =*/ nullptr,
|
|
};
|
|
|
|
return result;
|
|
}
|
|
|
|
struct llama_sampler_chain_params llama_sampler_chain_default_params() {
|
|
struct llama_sampler_chain_params result = {
|
|
/*.no_perf =*/ true,
|
|
};
|
|
|
|
return result;
|
|
}
|
|
|
|
size_t llama_max_devices(void) {
|
|
return 16;
|
|
}
|
|
|
|
bool llama_supports_mmap(void) {
|
|
return llama_mmap::SUPPORTED;
|
|
}
|
|
|
|
bool llama_supports_mlock(void) {
|
|
return llama_mlock::SUPPORTED;
|
|
}
|
|
|
|
bool llama_supports_gpu_offload(void) {
|
|
return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr ||
|
|
llama_supports_rpc();
|
|
}
|
|
|
|
bool llama_supports_rpc(void) {
|
|
return ggml_backend_reg_by_name("RPC") != nullptr;
|
|
}
|
|
|
|
void llama_backend_init(void) {
|
|
ggml_time_init();
|
|
|
|
// needed to initialize f16 tables
|
|
{
|
|
struct ggml_init_params params = { 0, NULL, false };
|
|
struct ggml_context * ctx = ggml_init(params);
|
|
ggml_free(ctx);
|
|
}
|
|
}
|
|
|
|
void llama_numa_init(enum ggml_numa_strategy numa) {
|
|
if (numa != GGML_NUMA_STRATEGY_DISABLED) {
|
|
auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
|
GGML_ASSERT(dev && "CPU backend is not loaded");
|
|
auto * reg = ggml_backend_dev_backend_reg(dev);
|
|
auto * numa_init_fn = (decltype(ggml_numa_init) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_numa_init");
|
|
numa_init_fn(numa);
|
|
}
|
|
}
|
|
|
|
void llama_backend_free(void) {
|
|
ggml_quantize_free();
|
|
}
|
|
|
|
int64_t llama_time_us(void) {
|
|
return ggml_time_us();
|
|
}
|
|
|
|
struct llama_model * llama_load_model_from_file(
|
|
const char * path_model,
|
|
struct llama_model_params params) {
|
|
return llama_model_load_from_file(path_model, params);
|
|
}
|
|
|
|
struct llama_model * llama_model_load_from_file(
|
|
const char * path_model,
|
|
struct llama_model_params params) {
|
|
ggml_time_init();
|
|
|
|
llama_model * model = new llama_model;
|
|
|
|
unsigned cur_percentage = 0;
|
|
if (params.progress_callback == NULL) {
|
|
params.progress_callback_user_data = &cur_percentage;
|
|
params.progress_callback = [](float progress, void * ctx) {
|
|
unsigned * cur_percentage_p = (unsigned *) ctx;
|
|
unsigned percentage = (unsigned) (100 * progress);
|
|
while (percentage > *cur_percentage_p) {
|
|
*cur_percentage_p = percentage;
|
|
LLAMA_LOG_CONT(".");
|
|
if (percentage >= 100) {
|
|
LLAMA_LOG_CONT("\n");
|
|
}
|
|
}
|
|
return true;
|
|
};
|
|
}
|
|
|
|
if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
|
|
        // split the comma-separated server list and store the entries in model->rpc_servers
|
|
std::string servers(params.rpc_servers);
|
|
size_t pos = 0;
|
|
while ((pos = servers.find(',')) != std::string::npos) {
|
|
std::string server = servers.substr(0, pos);
|
|
model->rpc_servers.push_back(server);
|
|
servers.erase(0, pos + 1);
|
|
}
|
|
model->rpc_servers.push_back(servers);
|
|
}
|
|
|
|
// add RPC devices
|
|
if (!model->rpc_servers.empty()) {
|
|
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
|
|
if (!rpc_reg) {
|
|
LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
|
|
llama_model_free(model);
|
|
return nullptr;
|
|
}
|
|
|
|
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
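        // the RPC backend is optional, so its device-creation entry point is resolved at runtime through
        // the backend registry rather than being referenced directly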
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
|
|
if (!ggml_backend_rpc_add_device_fn) {
|
|
LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
|
|
llama_model_free(model);
|
|
return nullptr;
|
|
}
|
|
|
|
for (const std::string & server : model->rpc_servers) {
|
|
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
|
|
if (dev) {
|
|
model->devices.push_back(dev);
|
|
} else {
|
|
LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
|
|
llama_model_free(model);
|
|
return nullptr;
|
|
}
|
|
}
|
|
}
|
|
|
|
// create list of devices to use with this model
|
|
if (params.devices) {
|
|
for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
|
|
model->devices.push_back(*dev);
|
|
}
|
|
} else {
|
|
// use all available devices
|
|
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
|
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
|
switch (ggml_backend_dev_type(dev)) {
|
|
case GGML_BACKEND_DEVICE_TYPE_CPU:
|
|
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
|
|
// skip CPU backends since they are handled separately
|
|
break;
|
|
|
|
case GGML_BACKEND_DEVICE_TYPE_GPU:
|
|
model->devices.push_back(dev);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// if using single GPU mode, remove all except the main GPU
|
|
if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
|
|
if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) {
|
|
LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size());
|
|
llama_model_free(model);
|
|
return nullptr;
|
|
}
|
|
ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
|
|
model->devices.clear();
|
|
model->devices.push_back(main_gpu);
|
|
}
|
|
|
|
for (auto * dev : model->devices) {
|
|
size_t free, total; // NOLINT
|
|
ggml_backend_dev_memory(dev, &free, &total);
|
|
LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
|
|
}
|
|
|
|
int status = llama_model_load(path_model, *model, params);
|
|
GGML_ASSERT(status <= 0);
|
|
if (status < 0) {
|
|
if (status == -1) {
|
|
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
|
|
} else if (status == -2) {
|
|
LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
|
|
}
|
|
|
|
llama_model_free(model);
|
|
return nullptr;
|
|
}
|
|
|
|
return model;
|
|
}
|
|
|
|
struct llama_context * llama_new_context_with_model(
|
|
struct llama_model * model,
|
|
struct llama_context_params params) {
|
|
|
|
if (!model) {
|
|
LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
|
|
return nullptr;
|
|
}
|
|
|
|
if (params.n_batch == 0 && params.n_ubatch == 0) {
|
|
LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
|
|
return nullptr;
|
|
}
|
|
|
|
if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
|
|
LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
|
|
return nullptr;
|
|
}
|
|
|
|
if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
|
|
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
|
|
params.flash_attn = false;
|
|
}
|
|
|
|
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
|
|
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
|
|
params.flash_attn = false;
|
|
}
|
|
|
|
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
|
|
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
|
|
return nullptr;
|
|
}
|
|
|
|
llama_context * ctx = new llama_context(*model);
|
|
|
|
const auto & hparams = model->hparams;
|
|
auto & cparams = ctx->cparams;
|
|
|
|
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
|
cparams.n_threads = params.n_threads;
|
|
cparams.n_threads_batch = params.n_threads_batch;
|
|
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
|
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
|
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
|
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
|
cparams.defrag_thold = params.defrag_thold;
|
|
cparams.embeddings = params.embeddings;
|
|
cparams.offload_kqv = params.offload_kqv;
|
|
cparams.flash_attn = params.flash_attn;
|
|
cparams.no_perf = params.no_perf;
|
|
cparams.pooling_type = params.pooling_type;
|
|
|
|
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
|
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
|
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
|
|
|
// this is necessary due to kv_self.n being padded later during inference
|
|
cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
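    // (illustrative example, assuming a KV cache padding of 256 when flash_attn is enabled:
    //  a requested n_ctx of 4000 is rounded up to GGML_PAD(4000, 256) = 4096)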
// with causal attention, the batch size is limited by the context size
|
|
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
|
|
|
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
|
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
|
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
|
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
|
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
|
cparams.n_batch = GGML_KQ_MASK_PAD;
|
|
}
|
|
|
|
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
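    // (e.g. with n_batch = 2048: an unset n_ubatch (0) becomes 2048, an explicit n_ubatch = 512 is kept,
    //  and a value larger than n_batch is clamped down to 2048)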
cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
|
|
hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
|
|
hparams.n_ctx_train;
|
|
|
|
cparams.cb_eval = params.cb_eval;
|
|
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
|
|
|
auto rope_scaling_type = params.rope_scaling_type;
|
|
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
|
|
rope_scaling_type = hparams.rope_scaling_type_train;
|
|
}
|
|
|
|
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
|
|
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
|
|
}
|
|
|
|
if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
|
|
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
|
|
}
|
|
|
|
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
|
|
|
|
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
|
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
|
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
|
} else {
|
|
cparams.pooling_type = hparams.pooling_type;
|
|
}
|
|
}
|
|
|
|
if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
|
|
cparams.causal_attn = hparams.causal_attn;
|
|
} else {
|
|
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
|
|
}
|
|
|
|
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
|
|
|
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
|
|
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
|
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
|
|
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
|
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
|
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
|
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
|
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
|
|
|
if (n_ctx_per_seq < hparams.n_ctx_train) {
|
|
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
|
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
|
}
|
|
|
|
if (n_ctx_per_seq > hparams.n_ctx_train) {
|
|
LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
|
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
|
}
|
|
|
|
ctx->logits_all = params.logits_all;
|
|
|
|
    // mark the context as encoding if the model contains an encoder - the worst-case graph built below depends on this flag
|
|
ctx->is_encoding = llama_model_has_encoder(model);
|
|
|
|
uint32_t kv_size = cparams.n_ctx;
|
|
ggml_type type_k = params.type_k;
|
|
ggml_type type_v = params.type_v;
|
|
|
|
// Mamba only needs a constant number of KV cache cells per sequence
|
|
if (llama_model_is_recurrent(model)) {
|
|
// Mamba needs at least as many KV cells as there are sequences kept at any time
|
|
kv_size = std::max((uint32_t) 1, params.n_seq_max);
|
|
// it's probably best to keep as much precision as possible for the states
|
|
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
|
|
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
|
|
}
|
|
|
|
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
|
|
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
|
|
|
if (!hparams.vocab_only) {
|
|
// GPU backends
|
|
for (auto * dev : model->devices) {
|
|
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
|
if (backend == nullptr) {
|
|
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
|
llama_free(ctx);
|
|
return nullptr;
|
|
}
|
|
ctx->backends.emplace_back(backend);
|
|
}
|
|
|
|
// add ACCEL backends (such as BLAS)
|
|
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
|
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
|
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
|
|
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
|
if (backend == nullptr) {
|
|
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
|
llama_free(ctx);
|
|
return nullptr;
|
|
}
|
|
ctx->backends.emplace_back(backend);
|
|
}
|
|
}
|
|
|
|
// add CPU backend
|
|
ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
|
if (ctx->backend_cpu == nullptr) {
|
|
LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
|
|
llama_free(ctx);
|
|
return nullptr;
|
|
}
|
|
ctx->backends.emplace_back(ctx->backend_cpu);
|
|
|
|
// create a list of the set_n_threads functions in the backends
|
|
for (auto & backend : ctx->backends) {
|
|
ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
|
|
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
|
if (reg) {
|
|
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
|
if (ggml_backend_set_n_threads_fn) {
|
|
ctx->set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
|
|
}
|
|
}
|
|
}
|
|
|
|
llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
|
|
|
|
if (!llama_kv_cache_init(ctx->kv_self, ctx->model, ctx->cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
|
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
|
llama_free(ctx);
|
|
return nullptr;
|
|
}
|
|
|
|
{
|
|
size_t memory_size_k = 0;
|
|
size_t memory_size_v = 0;
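            // (illustrative sizing for a hypothetical model: 32 layers, n_embd_k_gqa = n_embd_v_gqa = 1024,
            //  kv_size = 4096 and F16 cache types give 32 * 4096 * 1024 * 2 bytes = 256 MiB each for K and V)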
for (auto & k : ctx->kv_self.k_l) {
|
|
memory_size_k += ggml_nbytes(k);
|
|
}
|
|
|
|
for (auto & v : ctx->kv_self.v_l) {
|
|
memory_size_v += ggml_nbytes(v);
|
|
}
|
|
|
|
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
|
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
|
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
|
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
|
}
|
|
|
|
// graph outputs buffer
|
|
{
|
|
// resized during inference when a batch uses more outputs
|
|
if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) {
|
|
LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
|
|
llama_free(ctx);
|
|
return nullptr;
|
|
}
|
|
|
|
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
|
|
ggml_backend_buffer_name(ctx->buf_output.get()),
|
|
ggml_backend_buffer_get_size(ctx->buf_output.get()) / 1024.0 / 1024.0);
|
|
}
|
|
|
|
// scheduler and compute buffers
|
|
{
|
|
// buffer types used for the compute buffer of each backend
|
|
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
|
std::vector<ggml_backend_t> backend_ptrs;
|
|
for (auto & backend : ctx->backends) {
|
|
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
|
|
auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
|
if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
|
|
                    // for the CPU backend, use the host buffer of the first device to speed up transfers of the intermediate states
|
|
auto * dev = model->devices[0];
|
|
auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
|
|
if (host_buft) {
|
|
buft = host_buft;
|
|
}
|
|
}
|
|
backend_buft.push_back(buft);
|
|
backend_ptrs.push_back(backend.get());
|
|
}
|
|
|
|
const size_t max_nodes = llama_model_max_nodes(*model);
|
|
|
|
// buffer used to store the computation graph and the tensor meta data
|
|
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
|
|
|
// TODO: move these checks to ggml_backend_sched
|
|
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
|
bool pipeline_parallel =
|
|
llama_get_device_count(*model) > 1 &&
|
|
model->n_gpu_layers > (int)model->hparams.n_layer &&
|
|
model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
|
|
params.offload_kqv;
|
|
|
|
// pipeline parallelism requires support for async compute and events in all devices
|
|
if (pipeline_parallel) {
|
|
for (auto & backend : ctx->backends) {
|
|
auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
|
if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
|
// ignore CPU backend
|
|
continue;
|
|
}
|
|
auto * dev = ggml_backend_get_device(backend.get());
|
|
ggml_backend_dev_props props;
|
|
ggml_backend_dev_get_props(dev, &props);
|
|
if (!props.caps.async || !props.caps.events) {
|
|
// device does not support async compute or events
|
|
pipeline_parallel = false;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
ctx->sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
|
|
|
|
if (pipeline_parallel) {
|
|
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched.get()));
|
|
}
|
|
|
|
// initialize scheduler with the worst-case graph
|
|
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
|
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
            llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between the token and embedding input graphs
|
|
|
|
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
|
|
|
|
// reserve pp graph first so that buffers are only allocated once
|
|
ggml_backend_sched_reserve(ctx->sched.get(), gf_pp);
|
|
int n_splits_pp = ggml_backend_sched_get_n_splits(ctx->sched.get());
|
|
int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
|
|
|
|
// reserve with tg graph to get the number of splits and nodes
|
|
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
ggml_cgraph * gf_tg = llama_build_graph(*ctx, ubatch_tg, true);
|
|
ggml_backend_sched_reserve(ctx->sched.get(), gf_tg);
|
|
int n_splits_tg = ggml_backend_sched_get_n_splits(ctx->sched.get());
|
|
int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
|
|
|
|
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
|
gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
|
|
if (!ggml_backend_sched_reserve(ctx->sched.get(), gf_pp)) {
|
|
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
llama_free(ctx);
|
|
return nullptr;
|
|
}
|
|
|
|
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
|
ggml_backend_t backend = backend_ptrs[i];
|
|
ggml_backend_buffer_type_t buft = backend_buft[i];
|
|
size_t size = ggml_backend_sched_get_buffer_size(ctx->sched.get(), backend);
|
|
if (size > 1) {
|
|
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
|
ggml_backend_buft_name(buft),
|
|
size / 1024.0 / 1024.0);
|
|
}
|
|
}
|
|
|
|
if (n_nodes_pp == n_nodes_tg) {
|
|
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
|
|
} else {
|
|
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
|
|
}
|
|
if (n_splits_pp == n_splits_tg) {
|
|
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
|
|
} else {
|
|
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
|
|
}
|
|
}
|
|
}
|
|
|
|
return ctx;
|
|
}
|
|
|
|
//
|
|
// kv cache
|
|
//
|
|
|
|
// TODO: tmp bridges below until `struct llama_kv_cache` is exposed through the public API
|
|
|
|
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
|
|
return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
|
|
}
|
|
|
|
void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
|
|
llama_kv_cache_view_update(view, ctx->kv_self);
|
|
}
|
|
|
|
int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) {
|
|
return llama_get_kv_cache_token_count(ctx->kv_self);
|
|
}
|
|
|
|
int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
|
|
return llama_get_kv_cache_used_cells(ctx->kv_self);
|
|
}
|
|
|
|
void llama_kv_cache_clear(struct llama_context * ctx) {
|
|
llama_kv_cache_clear(ctx->kv_self);
|
|
}
|
|
|
|
bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
|
return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
|
|
}
|
|
|
|
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
|
if (seq_id_src == seq_id_dst) {
|
|
return;
|
|
}
|
|
llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
|
|
}
|
|
|
|
void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
|
|
llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
|
|
}
|
|
|
|
void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
|
if (delta == 0) {
|
|
return;
|
|
}
|
|
|
|
llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta);
|
|
}
|
|
|
|
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
if (d == 1) {
|
|
return;
|
|
}
|
|
|
|
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
|
|
}
|
|
|
|
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
|
|
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
|
}
|
|
|
|
void llama_kv_cache_defrag(struct llama_context * ctx) {
|
|
llama_kv_cache_defrag(ctx->kv_self);
|
|
}
|
|
|
|
void llama_kv_cache_update(struct llama_context * ctx) {
|
|
llama_kv_cache_update_impl(*ctx);
|
|
}
|
|
|
|
bool llama_kv_cache_can_shift(struct llama_context * ctx) {
|
|
return llama_kv_cache_can_shift(ctx->kv_self);
|
|
}
|
|
|
|
///
|
|
|
|
int32_t llama_encode(
|
|
struct llama_context * ctx,
|
|
struct llama_batch batch) {
|
|
const int ret = llama_encode_impl(*ctx, batch);
|
|
if (ret != 0) {
|
|
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
int32_t llama_decode(
|
|
struct llama_context * ctx,
|
|
struct llama_batch batch) {
|
|
const int ret = llama_decode_impl(*ctx, batch);
|
|
if (ret != 0) {
|
|
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
//
|
|
// vocab
|
|
//
|
|
|
|
// TODO: tmp bridges below until `struct llama_vocab` is exposed through the public API
|
|
|
|
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
|
|
return llama_token_get_text_impl(model->vocab, token);
|
|
}
|
|
|
|
float llama_token_get_score(const struct llama_model * model, llama_token token) {
|
|
return llama_token_get_score_impl(model->vocab, token);
|
|
}
|
|
|
|
enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
|
|
return llama_token_get_attr_impl(model->vocab, token);
|
|
}
|
|
|
|
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
|
|
return llama_token_is_eog_impl(model->vocab, token);
|
|
}
|
|
|
|
bool llama_token_is_control(const struct llama_model * model, llama_token token) {
|
|
return llama_token_is_control_impl(model->vocab, token);
|
|
}
|
|
|
|
llama_token llama_token_bos(const struct llama_model * model) {
|
|
return llama_token_bos_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_eos(const struct llama_model * model) {
|
|
return llama_token_eos_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_eot(const struct llama_model * model) {
|
|
return llama_token_eot_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_cls(const struct llama_model * model) {
|
|
return llama_token_cls_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_sep(const struct llama_model * model) {
|
|
return llama_token_sep_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_nl (const struct llama_model * model) {
|
|
return llama_token_nl_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_pad(const struct llama_model * model) {
|
|
return llama_token_pad_impl(model->vocab);
|
|
}
|
|
|
|
bool llama_add_bos_token(const struct llama_model * model) {
|
|
return llama_add_bos_token_impl(model->vocab);
|
|
}
|
|
|
|
bool llama_add_eos_token(const struct llama_model * model) {
|
|
return llama_add_eos_token_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_prefix(const struct llama_model * model) {
|
|
return llama_token_prefix_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_middle(const struct llama_model * model) {
|
|
return llama_token_middle_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_suffix(const struct llama_model * model) {
|
|
return llama_token_suffix_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_fim_pre(const struct llama_model * model) {
|
|
return llama_token_fim_pre_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_fim_suf(const struct llama_model * model) {
|
|
return llama_token_fim_suf_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_fim_mid(const struct llama_model * model) {
|
|
return llama_token_fim_mid_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_fim_pad(const struct llama_model * model) {
|
|
return llama_token_fim_pad_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_fim_rep(const struct llama_model * model) {
|
|
return llama_token_fim_rep_impl(model->vocab);
|
|
}
|
|
|
|
llama_token llama_token_fim_sep(const struct llama_model * model) {
|
|
return llama_token_fim_sep_impl(model->vocab);
|
|
}
|
|
|
|
//
|
|
// tokenization
|
|
//
|
|
|
|
int32_t llama_tokenize(
|
|
const struct llama_model * model,
|
|
const char * text,
|
|
int32_t text_len,
|
|
llama_token * tokens,
|
|
int32_t n_tokens_max,
|
|
bool add_special,
|
|
bool parse_special) {
|
|
return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
|
|
}
|
|
|
|
int32_t llama_token_to_piece(
|
|
const struct llama_model * model,
|
|
llama_token token,
|
|
char * buf,
|
|
int32_t length,
|
|
int32_t lstrip,
|
|
bool special) {
|
|
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
|
|
}
|
|
|
|
int32_t llama_detokenize(
|
|
const struct llama_model * model,
|
|
const llama_token * tokens,
|
|
int32_t n_tokens,
|
|
char * text,
|
|
int32_t text_len_max,
|
|
bool remove_special,
|
|
bool unparse_special) {
|
|
return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
|
|
}
|
|
|
|
//
|
|
// chat templates
|
|
//
|
|
|
|
int32_t llama_chat_apply_template(
|
|
const struct llama_model * model,
|
|
const char * tmpl,
|
|
const struct llama_chat_message * chat,
|
|
size_t n_msg,
|
|
bool add_ass,
|
|
char * buf,
|
|
int32_t length) {
|
|
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
|
|
if (tmpl == nullptr) {
|
|
GGML_ASSERT(model != nullptr);
|
|
|
|
// load template from model, if available
|
|
const auto & it = model->gguf_kv.find("tokenizer.chat_template");
|
|
if (it != model->gguf_kv.end() && it->second.size() > 0) {
|
|
curr_tmpl = it->second;
|
|
}
|
|
else {
|
|
            // worst case: there is no template information, so fall back to chatml by default
|
|
curr_tmpl = "chatml"; // see llm_chat_apply_template
|
|
}
|
|
}
|
|
|
|
// format the chat to string
|
|
std::vector<const llama_chat_message *> chat_vec;
|
|
chat_vec.resize(n_msg);
|
|
for (size_t i = 0; i < n_msg; i++) {
|
|
chat_vec[i] = &chat[i];
|
|
}
|
|
|
|
std::string formatted_chat;
|
|
llm_chat_template detected_tmpl = llm_chat_detect_template(curr_tmpl);
|
|
if (detected_tmpl == LLM_CHAT_TEMPLATE_UNKNOWN) {
|
|
return -1;
|
|
}
|
|
int32_t res = llm_chat_apply_template(detected_tmpl, chat_vec, formatted_chat, add_ass);
|
|
if (res < 0) {
|
|
return res;
|
|
}
|
|
if (buf && length > 0) {
|
|
strncpy(buf, formatted_chat.c_str(), length);
|
|
}
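    // res is the length of the fully formatted chat; if it is larger than `length`, the copy above was
    // truncated and the caller is expected to retry with a bigger buffer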
return res;
|
|
}
|
|
|
|
//
|
|
// sampling
|
|
//
|
|
|
|
// TODO: remove indirection when vocab becomes accessible in llama-sampling.cpp
|
|
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
|
|
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
|
|
}
|
|
|
|
struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
|
|
return llama_sampler_init_infill_impl(model->vocab);
|
|
}
|
|
|
|
struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
|
return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
|
|
}
|
|
|
|
//
|
|
// model split
|
|
//
|
|
|
|
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
|
|
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
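    // (e.g. path_prefix = "model", split_no = 0, split_count = 4 produces "model-00001-of-00004.gguf")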
if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
|
|
return strlen(split_path);
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
|
|
std::string str_split_path(split_path);
|
|
char postfix[32];
|
|
snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
|
|
std::string str_postfix(postfix);
|
|
|
|
    // check if split_path ends with the postfix
|
|
int size_prefix = str_split_path.size() - str_postfix.size();
|
|
if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
|
|
snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
|
|
return size_prefix;
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
const char * llama_print_system_info(void) {
|
|
static std::string s;
|
|
s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
|
|
|
|
|
|
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
|
auto * reg = ggml_backend_reg_get(i);
|
|
auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
|
|
if (get_features_fn) {
|
|
ggml_backend_feature * features = get_features_fn(reg);
|
|
s += ggml_backend_reg_name(reg);
|
|
s += " : ";
|
|
for (; features->name; features++) {
|
|
s += features->name;
|
|
s += " = ";
|
|
s += features->value;
|
|
s += " | ";
|
|
}
|
|
}
|
|
}
|
|
|
|
return s.c_str();
|
|
}
|
|
|
|
//
|
|
// perf
|
|
//
|
|
|
|
struct llama_perf_context_data llama_perf_context(const struct llama_context * ctx) {
|
|
struct llama_perf_context_data data = {};
|
|
|
|
if (ctx == nullptr) {
|
|
return data;
|
|
}
|
|
|
|
data.t_start_ms = 1e-3 * ctx->t_start_us;
|
|
data.t_load_ms = 1e-3 * ctx->t_load_us;
|
|
data.t_p_eval_ms = 1e-3 * ctx->t_p_eval_us;
|
|
data.t_eval_ms = 1e-3 * ctx->t_eval_us;
|
|
data.n_p_eval = std::max(1, ctx->n_p_eval);
|
|
data.n_eval = std::max(1, ctx->n_eval);
|
|
|
|
return data;
|
|
}
|
|
|
|
void llama_perf_context_print(const struct llama_context * ctx) {
|
|
const auto data = llama_perf_context(ctx);
|
|
|
|
const double t_end_ms = 1e-3 * ggml_time_us();
|
|
|
|
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
|
|
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
|
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
|
|
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
|
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
|
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
|
}
|
|
|
|
void llama_perf_context_reset(struct llama_context * ctx) {
|
|
ctx->t_start_us = ggml_time_us();
|
|
ctx->t_eval_us = ctx->n_eval = 0;
|
|
ctx->t_p_eval_us = ctx->n_p_eval = 0;
|
|
}
|