mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-06 00:20:34 +01:00
temporary refactor llama_vision_graph_builder
This commit is contained in:
parent
32daa38333
commit
9716c7bff7
@ -50,7 +50,7 @@ static llama_vision_bitmap * load_image_from_file(const char * fname) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// split string by a `std::string delim` instead of `char delim`
|
// split string by a `std::string delim` instead of `char delim`
|
||||||
static std::vector<std::string> string_split(std::string s, const std::string & delimiter) {
|
static std::vector<std::string> string_split_str(std::string s, const std::string & delimiter) {
|
||||||
std::vector<std::string> tokens;
|
std::vector<std::string> tokens;
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
std::string token;
|
std::string token;
|
||||||
@ -76,7 +76,7 @@ static std::vector<tokenized_part> tokenize_with_img_placement(
|
|||||||
const std::string & text,
|
const std::string & text,
|
||||||
bool add_special,
|
bool add_special,
|
||||||
bool parse_special) {
|
bool parse_special) {
|
||||||
std::vector<std::string> parts = string_split(text, IMG_PLACEMENT);
|
std::vector<std::string> parts = string_split_str(text, IMG_PLACEMENT);
|
||||||
std::vector<tokenized_part> output;
|
std::vector<tokenized_part> output;
|
||||||
for (const auto & part : parts) {
|
for (const auto & part : parts) {
|
||||||
//printf("tokenizing part: %s\n", part.c_str());
|
//printf("tokenizing part: %s\n", part.c_str());
|
||||||
@ -114,6 +114,10 @@ int main(int argc, char ** argv) {
|
|||||||
llama_context * ctx = llama_init.context.get();
|
llama_context * ctx = llama_init.context.get();
|
||||||
const llama_model * model = llama_init.model.get();
|
const llama_model * model = llama_init.model.get();
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
if (!model) {
|
||||||
|
LOG_ERR("failed to load model\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
struct common_sampler * smpl = common_sampler_init(model, params.sampling);
|
struct common_sampler * smpl = common_sampler_init(model, params.sampling);
|
||||||
|
|
||||||
|
@ -4056,6 +4056,11 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
|||||||
case LLM_ARCH_QWEN2VL:
|
case LLM_ARCH_QWEN2VL:
|
||||||
return LLAMA_ROPE_TYPE_MROPE;
|
return LLAMA_ROPE_TYPE_MROPE;
|
||||||
|
|
||||||
|
case LLM_ARCH_VISION_LLAVA:
|
||||||
|
case LLM_ARCH_VISION_MOBILEVLM:
|
||||||
|
case LLM_ARCH_VISION_MINICPMV:
|
||||||
|
GGML_ABORT("vision arch does not use RoPE");
|
||||||
|
|
||||||
// all model arches should be listed explicitly here
|
// all model arches should be listed explicitly here
|
||||||
case LLM_ARCH_UNKNOWN:
|
case LLM_ARCH_UNKNOWN:
|
||||||
GGML_ABORT("unknown architecture");
|
GGML_ABORT("unknown architecture");
|
||||||
|
@ -19,6 +19,8 @@ struct img_size;
|
|||||||
static int bmp_export(const struct llama_image_u8 &img, const std::string &location);
|
static int bmp_export(const struct llama_image_u8 &img, const std::string &location);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#define VISION_GRAPH_MAX_NODE 1024
|
||||||
|
|
||||||
struct img_size {
|
struct img_size {
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
@ -403,7 +405,7 @@ struct llama_vision_processor_llava : llama_vision_processor {
|
|||||||
output_slices.buf[0] = std::move(res);
|
output_slices.buf[0] = std::move(res);
|
||||||
|
|
||||||
return output_slices;
|
return output_slices;
|
||||||
};
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_vision_processor_uhd : llama_vision_processor {
|
struct llama_vision_processor_uhd : llama_vision_processor {
|
||||||
@ -572,33 +574,56 @@ struct llama_vision_processor_uhd : llama_vision_processor {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static ggml_cgraph * llama_vision_build_graph(llama_vision_context & ctx, int batch_size, img_size & image_size) {
|
// TODO: move this to llm_build_context in llama.cpp
|
||||||
auto & model = *ctx.model;
|
struct llama_vision_graph_builder {
|
||||||
auto & hparams = ctx.model->hparams;
|
llama_vision_context & ctx;
|
||||||
|
const llama_vision_model & model;
|
||||||
|
struct ggml_context * ctx0;
|
||||||
|
int batch_size;
|
||||||
|
int hidden_size;
|
||||||
|
int n_head;
|
||||||
|
int d_head;
|
||||||
|
int patch_size;
|
||||||
|
float eps;
|
||||||
|
int num_patches;
|
||||||
|
int num_positions;
|
||||||
|
int img_w;
|
||||||
|
int img_h;
|
||||||
|
bool use_gelu;
|
||||||
|
int n_layers;
|
||||||
|
vision_projector_type proj_type;
|
||||||
|
|
||||||
const int hidden_size = hparams.hidden_size;
|
llama_vision_graph_builder(llama_vision_context & ctx, const llama_vision_tokens & inp) : ctx(ctx), model(*ctx.model) {
|
||||||
const int n_head = hparams.n_head;
|
struct ggml_init_params params = {
|
||||||
const int d_head = hidden_size / n_head;
|
/*.mem_size =*/ ctx.buf_compute_meta.size(),
|
||||||
const int patch_size = hparams.patch_size;
|
/*.mem_buffer =*/ ctx.buf_compute_meta.data(),
|
||||||
const float eps = hparams.eps;
|
/*.no_alloc =*/ true,
|
||||||
const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size));
|
};
|
||||||
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
ctx0 = ggml_init(params);
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: num_patches = %d\n", __func__, num_patches);
|
auto & hparams = ctx.model->hparams;
|
||||||
|
|
||||||
struct ggml_init_params params = {
|
batch_size = inp.buf.size();
|
||||||
/*.mem_size =*/ ctx.buf_compute_meta.size(),
|
hidden_size = hparams.hidden_size;
|
||||||
/*.mem_buffer =*/ ctx.buf_compute_meta.data(),
|
n_head = hparams.n_head;
|
||||||
/*.no_alloc =*/ true,
|
d_head = hidden_size / n_head;
|
||||||
};
|
patch_size = hparams.patch_size;
|
||||||
|
eps = hparams.eps;
|
||||||
|
num_patches = inp.n_px * inp.n_py;
|
||||||
|
num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
||||||
|
img_w = inp.px * inp.n_px;
|
||||||
|
img_h = inp.py * inp.n_py;
|
||||||
|
use_gelu = hparams.use_gelu;
|
||||||
|
n_layers = (int)hparams.n_layer + hparams.select_layer;
|
||||||
|
proj_type = hparams.proj_type;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
~llama_vision_graph_builder() {
|
||||||
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
ggml_free(ctx0);
|
||||||
|
}
|
||||||
|
|
||||||
// input
|
struct ggml_tensor * build_inp() {
|
||||||
struct ggml_tensor * embeddings;
|
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, img_w, img_h, 3, batch_size);
|
||||||
{
|
|
||||||
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size.width, image_size.height, 3, batch_size);
|
|
||||||
ggml_set_name(inp_raw, "inp_raw");
|
ggml_set_name(inp_raw, "inp_raw");
|
||||||
ggml_set_input(inp_raw);
|
ggml_set_input(inp_raw);
|
||||||
|
|
||||||
@ -612,37 +637,51 @@ static ggml_cgraph * llama_vision_build_graph(llama_vision_context & ctx, int ba
|
|||||||
}
|
}
|
||||||
// auto * ne = inp->ne; printf("%d %d %d %d\n", ne[0], ne[1], ne[2], ne[3]);
|
// auto * ne = inp->ne; printf("%d %d %d %d\n", ne[0], ne[1], ne[2], ne[3]);
|
||||||
|
|
||||||
embeddings = inp;
|
struct ggml_tensor * embd = inp;
|
||||||
if (model.class_embedding) {
|
if (model.class_embedding) {
|
||||||
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
|
embd = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
|
||||||
ggml_set_name(embeddings, "embeddings");
|
ggml_set_name(embd, "inp_embd");
|
||||||
ggml_set_input(embeddings);
|
ggml_set_input(embd);
|
||||||
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
|
|
||||||
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
|
embd = ggml_acc(ctx0, embd, model.class_embedding,
|
||||||
embeddings = ggml_acc(ctx0, embeddings, inp,
|
embd->nb[1], embd->nb[2], embd->nb[3], 0);
|
||||||
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
|
embd = ggml_acc(ctx0, embd, inp,
|
||||||
|
embd->nb[1], embd->nb[2], embd->nb[3], model.class_embedding->nb[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
||||||
ggml_set_name(positions, "positions");
|
ggml_set_name(positions, "inp_pos");
|
||||||
ggml_set_input(positions);
|
ggml_set_input(positions);
|
||||||
|
|
||||||
embeddings = ggml_add(ctx0,
|
embd = ggml_add(ctx0,
|
||||||
embeddings,
|
embd,
|
||||||
ggml_get_rows(ctx0, model.position_embeddings, positions));
|
ggml_get_rows(ctx0, model.position_embeddings, positions));
|
||||||
|
|
||||||
|
return embd;
|
||||||
}
|
}
|
||||||
|
|
||||||
// pre-layernorm
|
struct ggml_tensor * build_pre_norm(struct ggml_tensor * cur) {
|
||||||
if (model.pre_norm_w) {
|
if (model.pre_norm_w) {
|
||||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
cur = ggml_norm(ctx0, cur, eps);
|
||||||
ggml_set_name(embeddings, "pre_ln");
|
ggml_set_name(cur, "pre_ln");
|
||||||
|
|
||||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_b);
|
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.pre_norm_w), model.pre_norm_b);
|
||||||
|
}
|
||||||
|
return cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop over layers
|
struct ggml_tensor * build_post_norm(struct ggml_tensor * cur) {
|
||||||
for (int il = 0; il < (int)hparams.n_layer + hparams.select_layer; il++) {
|
if (model.post_norm_w) {
|
||||||
struct ggml_tensor * cur = embeddings;
|
cur = ggml_norm(ctx0, cur, eps);
|
||||||
|
ggml_set_name(cur, "post_ln");
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.post_norm_w), model.post_norm_b);
|
||||||
|
}
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * build_layer(struct ggml_tensor * inpL, int il) {
|
||||||
|
struct ggml_tensor * cur = inpL;
|
||||||
|
|
||||||
// layernorm1
|
// layernorm1
|
||||||
{
|
{
|
||||||
@ -654,7 +693,6 @@ static ggml_cgraph * llama_vision_build_graph(llama_vision_context & ctx, int ba
|
|||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
|
|
||||||
struct ggml_tensor * Q = ggml_add(ctx0,
|
struct ggml_tensor * Q = ggml_add(ctx0,
|
||||||
ggml_mul_mat(ctx0, model.layers[il].q_w, cur),
|
ggml_mul_mat(ctx0, model.layers[il].q_w, cur),
|
||||||
model.layers[il].q_b);
|
model.layers[il].q_b);
|
||||||
@ -693,9 +731,9 @@ static ggml_cgraph * llama_vision_build_graph(llama_vision_context & ctx, int ba
|
|||||||
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].output_w, cur), model.layers[il].output_b);
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].output_w, cur), model.layers[il].output_b);
|
||||||
|
|
||||||
// re-add the layer input, e.g., residual
|
// re-add the layer input, e.g., residual
|
||||||
cur = ggml_add(ctx0, cur, embeddings);
|
cur = ggml_add(ctx0, cur, inpL);
|
||||||
|
|
||||||
embeddings = cur; // embeddings = residual, cur = hidden_states
|
inpL = cur; // inpL = residual, cur = hidden_states
|
||||||
|
|
||||||
// layernorm2
|
// layernorm2
|
||||||
{
|
{
|
||||||
@ -708,7 +746,7 @@ static ggml_cgraph * llama_vision_build_graph(llama_vision_context & ctx, int ba
|
|||||||
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_w, cur);
|
||||||
cur = ggml_add(ctx0, cur, model.layers[il].ffn_up_b);
|
cur = ggml_add(ctx0, cur, model.layers[il].ffn_up_b);
|
||||||
|
|
||||||
if (hparams.use_gelu) {
|
if (use_gelu) {
|
||||||
cur = ggml_gelu_inplace(ctx0, cur);
|
cur = ggml_gelu_inplace(ctx0, cur);
|
||||||
} else {
|
} else {
|
||||||
cur = ggml_gelu_quick_inplace(ctx0, cur);
|
cur = ggml_gelu_quick_inplace(ctx0, cur);
|
||||||
@ -718,74 +756,76 @@ static ggml_cgraph * llama_vision_build_graph(llama_vision_context & ctx, int ba
|
|||||||
cur = ggml_add(ctx0, cur, model.layers[il].ffn_down_b);
|
cur = ggml_add(ctx0, cur, model.layers[il].ffn_down_b);
|
||||||
|
|
||||||
// residual 2
|
// residual 2
|
||||||
cur = ggml_add(ctx0, embeddings, cur);
|
cur = ggml_add(ctx0, inpL, cur);
|
||||||
|
|
||||||
embeddings = cur;
|
return cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
// post-layernorm
|
// graph for each vision arch
|
||||||
if (model.post_norm_w) {
|
|
||||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
|
||||||
ggml_set_name(embeddings, "post_ln");
|
|
||||||
|
|
||||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_norm_w), model.post_norm_b);
|
struct ggml_cgraph * build_llava() {
|
||||||
}
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false);
|
||||||
|
struct ggml_tensor * cur = build_inp();
|
||||||
// llava projector
|
cur = build_pre_norm(cur);
|
||||||
{
|
for (int il = 0; il < n_layers; il++) {
|
||||||
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
|
cur = build_layer(cur, il);
|
||||||
|
|
||||||
struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
|
|
||||||
ggml_set_name(patches, "patches");
|
|
||||||
ggml_set_input(patches);
|
|
||||||
|
|
||||||
// shape [1, 576, 1024]
|
|
||||||
// ne is whcn, ne = [1024, 576, 1, 1]
|
|
||||||
embeddings = ggml_get_rows(ctx0, embeddings, patches);
|
|
||||||
|
|
||||||
if (hparams.proj_type == VISION_PROJECTOR_TYPE_MLP) {
|
|
||||||
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
|
|
||||||
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
|
||||||
|
|
||||||
embeddings = ggml_gelu(ctx0, embeddings);
|
|
||||||
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
|
||||||
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
|
|
||||||
|
|
||||||
} else if (hparams.proj_type == VISION_PROJECTOR_TYPE_LDPV2) {
|
|
||||||
int n_patch = 24;
|
|
||||||
struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
|
|
||||||
mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
|
|
||||||
mlp_0 = ggml_gelu(ctx0, mlp_0);
|
|
||||||
struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
|
|
||||||
mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
|
|
||||||
// mlp_2 ne = [2048, 576, 1, 1]
|
|
||||||
// // AVG Pool Layer 2*2, strides = 2
|
|
||||||
mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
|
|
||||||
// mlp_2 ne = [576, 2048, 1, 1]
|
|
||||||
mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
|
|
||||||
// mlp_2 ne [24, 24, 2048, 1]
|
|
||||||
mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
|
|
||||||
// weight ne = [3, 3, 2048, 1]
|
|
||||||
struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
|
|
||||||
peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
|
|
||||||
peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
|
|
||||||
mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
|
|
||||||
peg_0 = ggml_add(ctx0, peg_0, mlp_2);
|
|
||||||
peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
|
|
||||||
embeddings = peg_0;
|
|
||||||
|
|
||||||
} else {
|
|
||||||
GGML_ASSERT(false && "unsupported proj type");
|
|
||||||
}
|
}
|
||||||
|
cur = build_post_norm(cur);
|
||||||
|
|
||||||
|
// llava projector
|
||||||
|
{
|
||||||
|
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1]);
|
||||||
|
|
||||||
|
struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
|
||||||
|
ggml_set_name(patches, "inp_patches");
|
||||||
|
ggml_set_input(patches);
|
||||||
|
|
||||||
|
// shape [1, 576, 1024]
|
||||||
|
// ne is whcn, ne = [1024, 576, 1, 1]
|
||||||
|
cur = ggml_get_rows(ctx0, cur, patches);
|
||||||
|
|
||||||
|
if (proj_type == VISION_PROJECTOR_TYPE_MLP) {
|
||||||
|
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||||
|
cur = ggml_add(ctx0, cur, model.mm_1_b);
|
||||||
|
|
||||||
|
cur = ggml_gelu(ctx0, cur);
|
||||||
|
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||||
|
cur = ggml_add(ctx0, cur, model.mm_2_b);
|
||||||
|
|
||||||
|
} else if (proj_type == VISION_PROJECTOR_TYPE_LDPV2) {
|
||||||
|
int n_patch = 24;
|
||||||
|
struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, cur);
|
||||||
|
mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
|
||||||
|
mlp_0 = ggml_gelu(ctx0, mlp_0);
|
||||||
|
struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
|
||||||
|
mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
|
||||||
|
// mlp_2 ne = [2048, 576, 1, 1]
|
||||||
|
// // AVG Pool Layer 2*2, strides = 2
|
||||||
|
mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
|
||||||
|
// mlp_2 ne = [576, 2048, 1, 1]
|
||||||
|
mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
|
||||||
|
// mlp_2 ne [24, 24, 2048, 1]
|
||||||
|
mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
|
||||||
|
// weight ne = [3, 3, 2048, 1]
|
||||||
|
struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
|
||||||
|
peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
|
||||||
|
peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
|
||||||
|
mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
|
||||||
|
peg_0 = ggml_add(ctx0, peg_0, mlp_2);
|
||||||
|
peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
|
||||||
|
cur = ggml_cont(ctx0, peg_0);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(false && "unsupported proj type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_set_name(cur, "output");
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
embeddings = ggml_cont(ctx0, embeddings);
|
|
||||||
|
|
||||||
// build the graph
|
|
||||||
ggml_build_forward_expand(gf, embeddings);
|
|
||||||
ggml_free(ctx0);
|
|
||||||
return gf;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) {
|
static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) {
|
||||||
int batch_size = inp.buf.size();
|
int batch_size = inp.buf.size();
|
||||||
@ -805,7 +845,16 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
|
|||||||
LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions);
|
LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions);
|
||||||
|
|
||||||
// build the inference graph
|
// build the inference graph
|
||||||
ggml_cgraph * gf = llama_vision_build_graph(ctx, batch_size, image_size);
|
llama_vision_graph_builder builder(ctx, inp);
|
||||||
|
ggml_cgraph * gf;
|
||||||
|
switch(hparams.arch) {
|
||||||
|
case LLM_ARCH_VISION_LLAVA:
|
||||||
|
case LLM_ARCH_VISION_MOBILEVLM:
|
||||||
|
gf = builder.build_llava();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false && "unsupported arch");
|
||||||
|
}
|
||||||
|
|
||||||
// alloc memory for graph
|
// alloc memory for graph
|
||||||
bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf);
|
bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf);
|
||||||
@ -839,16 +888,12 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (model.class_embedding) {
|
if (model.class_embedding) {
|
||||||
struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
|
struct ggml_tensor * inp_embd = ggml_graph_get_tensor(gf, "inp_embd");
|
||||||
|
ggml_set_zero(inp_embd);
|
||||||
void* zero_mem = malloc(ggml_nbytes(embeddings));
|
|
||||||
memset(zero_mem, 0, ggml_nbytes(embeddings));
|
|
||||||
ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
|
|
||||||
free(zero_mem);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "inp_pos");
|
||||||
|
|
||||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||||
for (int i = 0; i < num_positions; i++) {
|
for (int i = 0; i < num_positions; i++) {
|
||||||
@ -859,7 +904,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "inp_patches");
|
||||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||||
for (int i = 0; i < num_patches; i++) {
|
for (int i = 0; i < num_patches; i++) {
|
||||||
patches_data[i] = i + 1;
|
patches_data[i] = i + 1;
|
||||||
|
Loading…
Reference in New Issue
Block a user