mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-02-05 16:10:42 +01:00)

rename everywhere

This commit is contained in:
parent: bd0714b977
commit: ad38e87329
@@ -122,7 +122,7 @@ int main(int argc, char ** argv) {
     int n_prompt = 0;
 
     // process image
-    llama_vision_patches * img_patches = nullptr;
+    llama_vision_tokens * img_tokens = nullptr;
     {
         const char * img_path = params.image[0].c_str();
         if (params.image[0].empty()) {
@@ -131,12 +131,12 @@ int main(int argc, char ** argv) {
         }
         llama_vision_bitmap * img = load_image_from_file(img_path);
         LOG_INF("loaded image %s, size = %d x %d\n", img_path, img->nx, img->ny);
-        img_patches = llama_vision_patches_init(ctx, img);
-        if (!img_patches) {
-            LOG_ERR("failed to create image patches\n");
+        img_tokens = llama_vision_tokenize(ctx, img);
+        if (!img_tokens) {
+            LOG_ERR("failed to create image tokens\n");
             return 1;
         }
-        if (llama_vision_encode(ctx, img_patches)) {
+        if (llama_vision_encode(ctx, img_tokens)) {
             LOG_ERR("failed to encode image\n");
             return 1;
         }
@@ -229,7 +229,9 @@ extern "C" {
         bool sorted;
     } llama_token_data_array;
 
-    struct llama_vision_patches;
+    // Structure represents the basic input unit of vision model
+    // This can be a processed image or slices of images under the hood
+    struct llama_vision_tokens;
 
     // represent an RGB image
     // size of data must be equal to 3*nx*ny
@@ -1286,12 +1288,15 @@ extern "C" {
     LLAMA_API struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny);
     LLAMA_API void llama_vision_bitmap_free(struct llama_vision_bitmap * bmp);
 
-    // Create patches from the RGB bitmap
-    LLAMA_API struct llama_vision_patches * llama_vision_patches_init(struct llama_context * ctx, llama_vision_bitmap * bmp);
-    LLAMA_API void llama_vision_patches_free(struct llama_vision_patches * p);
+    // Create image tokens from the RGB bitmap
+    LLAMA_API struct llama_vision_tokens * llama_vision_tokenize(struct llama_context * ctx, llama_vision_bitmap * bmp);
+    LLAMA_API void llama_vision_tokens_free(struct llama_vision_tokens * img_tokens);
+
+    // User must reserve N number of tokens in tokenized text prompt for each image
+    // LLAMA_API int32_t llama_vision_get_n_tokens(const llama_vision_img_tokens * img_tokens);
 
     // Encode patches into embeddings
-    LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_patches * p);
+    LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_tokens * img_tokens);
     LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_context * ctx);
 
     //
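The two hunks above are the whole public surface of the rename: `llama_vision_patches` / `llama_vision_patches_init` / `llama_vision_patches_free` become `llama_vision_tokens` / `llama_vision_tokenize` / `llama_vision_tokens_free`, while `llama_vision_encode` and `llama_vision_get_output_tensor` keep their names. A minimal caller-side sketch of the renamed flow, pieced together from the example-program hunk above (context setup and error handling abbreviated; `load_image_from_file` is the example's own helper, not part of the public API):

    // assumes: llama_context * ctx was created from a model with a vision tower
    llama_vision_bitmap * img = load_image_from_file("image.jpg");
    llama_vision_tokens * img_tokens = llama_vision_tokenize(ctx, img);
    if (!img_tokens || llama_vision_encode(ctx, img_tokens) != 0) {
        return 1; // tokenization or encoding failed
    }
    // the resulting embeddings are picked up later via the output tensor
    struct ggml_tensor * embd = llama_vision_get_output_tensor(ctx);
    llama_vision_tokens_free(img_tokens);
    llama_vision_bitmap_free(img);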
@@ -110,7 +110,7 @@ struct llama_context {
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
 
     // vision
-    clip_context vctx;
+    llama_vision_context vctx;
 };
 
 // TODO: make these methods of llama_context
@@ -1268,8 +1268,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 std::string name;
                 ml.get_key(LLM_KV_VISION_VIT_PROJECTOR_TYPE, name, true);
-                vparams.proj_type = clip_projector_type_from_name(name);
-                if (vparams.proj_type == CLIP_PROJECTOR_TYPE_UNKNOWN) {
+                vparams.proj_type = vision_projector_type_from_name(name);
+                if (vparams.proj_type == VISION_PROJECTOR_TYPE_UNKNOWN) {
                     throw std::runtime_error(format("unsupported clip projector type: %s", name.c_str()));
                 }
             }
@@ -3514,7 +3514,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             throw std::runtime_error("unknown vision architecture");
         }
 
-        if (clip_n_mmproj_embd(clip) != hparams.n_embd) {
+        if (llama_vision_n_mmproj_embd(clip) != hparams.n_embd) {
             std::runtime_error("model has vision, but n_mmproj_embd != n_embd");
         }
     }
@@ -365,7 +365,7 @@ struct llama_model {
 
     // vision
     bool has_vision = false;
-    clip_vision_model clip;
+    llama_vision_model clip;
 
 private:
     struct impl;
@@ -13,25 +13,25 @@
 #include <cstdint>
 #include <iostream>
 
-// export clip_image_u8 to bmp file for debugging
+// export llama_image_u8 to bmp file for debugging
 // https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
-struct clip_image_size;
-static int bmp_export(const struct clip_image_u8 &img, const std::string &location);
+struct img_size;
+static int bmp_export(const struct llama_image_u8 &img, const std::string &location);
 #endif
 
-struct clip_image_size {
+struct img_size {
     int width;
     int height;
 };
 
 // RGB uint8 image
 // Memory layout: RGBRGBRGB...
-struct clip_image_u8 {
+struct llama_image_u8 {
     int nx;
     int ny;
     std::vector<uint8_t> buf;
-    clip_image_u8() {}
-    clip_image_u8(const llama_vision_bitmap & bmp) {
+    llama_image_u8() {}
+    llama_image_u8(const llama_vision_bitmap & bmp) {
         nx = bmp.nx;
         ny = bmp.ny;
         buf.resize(nx*ny*3);
@@ -39,39 +39,43 @@ struct clip_image_u8 {
     }
 };
 
-struct clip_image_u8_batch {
-    struct clip_image_u8 * data;
-    size_t size;
-};
-
-static int clip_n_patches_x(const clip_context & ctx) {
-    auto & hparams = ctx.model->hparams;
-    return hparams.image_size / hparams.patch_size;
-}
-
-static int clip_n_patches_y(const clip_context & ctx) {
-    return clip_n_patches_x(ctx);
-}
-
-static int clip_n_patches(const clip_context & ctx) {
-    return clip_n_patches_x(ctx) * clip_n_patches_y(ctx);
-}
-
-uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model) {
-    auto & proj_type = clip_model.hparams.proj_type;
-    if (proj_type == CLIP_PROJECTOR_TYPE_MLP) {
-        return clip_model.mm_2_b->ne[0];
-    } else if (proj_type == CLIP_PROJECTOR_TYPE_LDPV2) {
-        return clip_model.mm_model_peg_0_b->ne[0];
-    } else if (proj_type == CLIP_PROJECTOR_TYPE_MINICPMV_2_5) {
+uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel) {
+    auto & proj_type = vmodel.hparams.proj_type;
+    if (proj_type == VISION_PROJECTOR_TYPE_MLP) {
+        return vmodel.mm_2_b->ne[0];
+    } else if (proj_type == VISION_PROJECTOR_TYPE_LDPV2) {
+        return vmodel.mm_model_peg_0_b->ne[0];
+    } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) {
         return 4096;
-    } else if (proj_type == CLIP_PROJECTOR_TYPE_MINICPMV_2_6) {
+    } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_6) {
         return 3584;
     } else {
         GGML_ASSERT(false && "invalid proj type");
     }
 }
 
+//
+// internal utils
+//
+
+static int get_n_patches_x(const llama_vision_context & ctx) {
+    auto & hparams = ctx.model->hparams;
+    return hparams.image_size / hparams.patch_size;
+}
+
+static int get_n_patches_y(const llama_vision_context & ctx) {
+    return get_n_patches_x(ctx);
+}
+
+static int get_n_patches(const llama_vision_context & ctx) {
+    return get_n_patches_x(ctx) * get_n_patches_y(ctx);
+}
+
+//
+// bitmap utils
+//
+
 /**
  * Selects the best resolution from a list of possible resolutions based on the original size.
  *
@@ -79,11 +83,11 @@ uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model) {
  * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
  * @return The best fit resolution in the format (width, height).
  */
-static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector<clip_image_size>& possible_resolutions) {
+static img_size select_best_resolution(const img_size & original_size, const std::vector<img_size>& possible_resolutions) {
     int original_width = original_size.width;
     int original_height = original_size.height;
 
-    clip_image_size best_fit;
+    img_size best_fit;
     int max_effective_resolution = 0;
     int min_wasted_resolution = std::numeric_limits<int>::max();
 
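The selection loop itself sits between these two hunks and is untouched by the commit, so it does not appear here. From the two accumulators above, it follows the usual LLaVA rule: fit the original into each candidate without upscaling, prefer the candidate that preserves the most effective resolution, and break ties by the least wasted (padding) area. A hedged sketch of that rule using the new `img_size` type (an illustration of the criterion, not the verbatim body):

    for (const img_size & cand : possible_resolutions) {
        // downscale factor that fits the original inside the candidate
        const float scale = std::min((float) cand.width  / original_width,
                                     (float) cand.height / original_height);
        const int eff_w = std::min((int) (original_width  * scale), original_width);
        const int eff_h = std::min((int) (original_height * scale), original_height);
        const int effective = eff_w * eff_h;                        // usable pixels
        const int wasted    = cand.width * cand.height - effective; // padding pixels
        if (effective > max_effective_resolution ||
           (effective == max_effective_resolution && wasted < min_wasted_resolution)) {
            max_effective_resolution = effective;
            min_wasted_resolution    = wasted;
            best_fit = cand;
        }
    }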
@@ -106,7 +110,7 @@ static clip_image_size select_best_resolution(const clip_image_size & original_s
     return best_fit;
 }
 
-static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
+static bool bicubic_resize(const llama_image_u8 & img, llama_image_u8 & dst, int target_width, int target_height) {
     auto clip = [](int x, int lower, int upper) -> int {
         return std::max(lower, std::min(x, upper));
     };
@@ -173,13 +177,13 @@ static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int t
     return true;
 }
 
-static std::vector<clip_image_u8> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
-    std::vector<clip_image_u8> patches;
+static std::vector<llama_image_u8> divide_to_patches_u8(const llama_image_u8 & image, int patch_size) {
+    std::vector<llama_image_u8> patches;
     int width = image.nx;
     int height = image.ny;
     for (int i = 0; i < height; i += patch_size) {
         for (int j = 0; j < width; j += patch_size) {
-            clip_image_u8 patch;
+            llama_image_u8 patch;
             patch.nx = std::min(patch_size, width - j);
             patch.ny = std::min(patch_size, height - i);
             patch.buf.resize(3 * patch.nx * patch.ny);
@@ -197,7 +201,7 @@ static std::vector<clip_image_u8> divide_to_patches_u8(const clip_image_u8 & ima
 }
 
 // llava-1.6 type of resize_and_pad (black)
-static clip_image_u8 resize_and_pad_image(const clip_image_u8 & image, const clip_image_size & target_resolution) {
+static llama_image_u8 resize_and_pad_image(const llama_image_u8 & image, const img_size & target_resolution) {
     int target_width = target_resolution.width;
     int target_height = target_resolution.height;
 
@@ -214,11 +218,11 @@ static clip_image_u8 resize_and_pad_image(const clip_image_u8 & image, const cli
         new_width = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
     }
 
-    clip_image_u8 resized_image;
+    llama_image_u8 resized_image;
     // bilinear_resize(image, resized_image, new_width, new_height);
     bicubic_resize(image, resized_image, new_width, new_height);
 
-    clip_image_u8 padded_image;
+    llama_image_u8 padded_image;
     padded_image.nx = target_width;
     padded_image.ny = target_height;
     padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black
@@ -238,7 +242,7 @@ static clip_image_u8 resize_and_pad_image(const clip_image_u8 & image, const cli
     return padded_image;
 }
 
-static void normalize_image_u8_to_f32(const clip_image_u8 & src, std::vector<float> & dst, const std::array<float, 3> & mean, const std::array<float, 3> & std) {
+static void normalize_image_u8_to_f32(const llama_image_u8 & src, std::vector<float> & dst, const std::array<float, 3> & mean, const std::array<float, 3> & std) {
     dst.resize(src.buf.size());
 
     for (size_t i = 0; i < src.buf.size(); ++i) {
@@ -247,15 +251,169 @@ static void normalize_image_u8_to_f32(const clip_image_u8 & src, std::vector<flo
     }
 }
 
-#define LLAMA_LOG_DEBUG LLAMA_LOG_INFO
-
-// minicpmv preprocessor
-struct minicpmv_preprocessor {
+//
+// processor
+//
+
+struct llama_vision_processor {
+    const llama_vision_context & ctx;
+    llama_vision_processor(const llama_vision_context & ctx) : ctx(ctx) {}
+    virtual llama_vision_tokens tokenize(const llama_image_u8 & img) = 0;
+    virtual ~llama_vision_processor() = default;
+};
+
+// inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
+struct llama_vision_processor_llava : llama_vision_processor {
+    llama_vision_processor_llava(const llama_vision_context & ctx) : llama_vision_processor(ctx) {}
+
+    virtual llama_vision_tokens tokenize(const llama_image_u8 & img) override {
+        bool pad_to_square = true;
+        auto & params = ctx.model->hparams;
+        // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
+        if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) {
+            pad_to_square = false;
+        }
+
+        llama_vision_tokens output_slices;
+        output_slices.n_px = get_n_patches_x(ctx);
+        output_slices.n_py = get_n_patches_y(ctx);
+        output_slices.px = params.patch_size;
+        output_slices.py = params.patch_size;
+
+        // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+        // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+
+        llama_image_u8 temp;
+        if (pad_to_square && img.nx != img.ny) {
+            // if the image is not square, pad it to a square
+            int longer_side = std::max(img.nx, img.ny);
+            temp.nx = longer_side;
+            temp.ny = longer_side;
+            temp.buf.resize(3 * longer_side * longer_side);
+            const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
+
+            // fill with background color
+            for (size_t i = 0; i < temp.buf.size(); i++) {
+                temp.buf[i] = bc[i % 3];
+            }
+
+            // copy from the input image
+            for (int y = 0; y < img.ny; y++) {
+                for (int x = 0; x < img.nx; x++) {
+                    const int i = 3 * (y * img.nx + x);
+                    const int j = 3 * (y * temp.nx + x);
+                    temp.buf[j]   = img.buf[i];
+                    temp.buf[j+1] = img.buf[i+1];
+                    temp.buf[j+2] = img.buf[i+2];
+                }
+            }
+        } else if (params.image_grid_pinpoints[0] != 0) {
+            // "spatial_unpad" with "anyres" processing for llava-1.6
+            std::vector<img_size> possible_resolutions;
+            for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) {
+                img_size s;
+                s.width  = params.image_grid_pinpoints[i];
+                s.height = params.image_grid_pinpoints[i+1];
+                possible_resolutions.push_back(s);
+            }
+            img_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions);
+            // debug_image_save_to_bmp(*img, "input.bmp");
+            temp = resize_and_pad_image(img, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6
+            // debug_image_save_to_bmp(*temp, "resized.bmp");
+
+            std::vector<llama_image_u8> patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
+
+            llama_image_u8 image_original_resize;
+            // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            patches.insert(patches.begin(), image_original_resize);
+            output_slices.buf.resize(patches.size());
+            int num = 0;
+            for (auto & patch : patches) {
+                normalize_image_u8_to_f32(patch, output_slices.buf[num], params.image_mean, params.image_std);
+                num++;
+            }
+            return output_slices;
+        } else {
+            temp.nx = img.nx;
+            temp.ny = img.ny;
+            temp.buf.resize(img.buf.size());
+            memcpy(temp.buf.data(), img.buf.data(), temp.buf.size());
+        }
+
+        const int nx = temp.nx;
+        const int ny = temp.ny;
+        // bmp_export(temp, "resized_vanilla.bmp");
+
+        const int nx2 = params.image_size;
+        const int ny2 = params.image_size;
+        std::vector<float> res;
+        res.resize(3 * nx2 * ny2);
+
+        const float scale = std::max(nx, ny) / (float)params.image_size;
+
+        const int nx3 = int(nx / scale + 0.5f);
+        const int ny3 = int(ny / scale + 0.5f);
+
+        const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
+        const auto & s3 = params.image_std;  // {0.26862954f, 0.26130258f, 0.27577711f};
+
+        for (int y = 0; y < ny3; y++) {
+            for (int x = 0; x < nx3; x++) {
+                for (int c = 0; c < 3; c++) {
+                    // linear interpolation
+                    const float sx = (x + 0.5f) * scale - 0.5f;
+                    const float sy = (y + 0.5f) * scale - 0.5f;
+
+                    const int x0 = std::max(0, (int)std::floor(sx));
+                    const int y0 = std::max(0, (int)std::floor(sy));
+
+                    const int x1 = std::min(x0 + 1, nx - 1);
+                    const int y1 = std::min(y0 + 1, ny - 1);
+
+                    const float dx = sx - x0;
+                    const float dy = sy - y0;
+
+                    const int j00 = 3 * (y0 * nx + x0) + c;
+                    const int j01 = 3 * (y0 * nx + x1) + c;
+                    const int j10 = 3 * (y1 * nx + x0) + c;
+                    const int j11 = 3 * (y1 * nx + x1) + c;
+
+                    const float v00 = temp.buf[j00];
+                    const float v01 = temp.buf[j01];
+                    const float v10 = temp.buf[j10];
+                    const float v11 = temp.buf[j11];
+
+                    const float v0 = v00 * (1.0f - dx) + v01 * dx;
+                    const float v1 = v10 * (1.0f - dx) + v11 * dx;
+
+                    const float v = v0 * (1.0f - dy) + v1 * dy;
+
+                    const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
+
+                    const int i = 3 * (y * nx3 + x) + c;
+
+                    res[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
+                }
+            }
+        }
+
+        output_slices.buf.resize(1);
+        output_slices.buf[0] = std::move(res);
+
+        return output_slices;
+    };
+};
+
+struct llama_vision_processor_uhd : llama_vision_processor {
+    llama_vision_processor_uhd(const llama_vision_context & ctx) : llama_vision_processor(ctx) {}
+
     int ensure_divide(int length, int patch_size) {
         return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
     }
 
-    std::pair<int, int> uhd_find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
+    std::pair<int, int> find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
        int width = original_size.first;
        int height = original_size.second;
        if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
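`ensure_divide` above snaps a side length to the nearest multiple of `patch_size`, never going below one patch, which is what lets `find_best_resize` return grid-aligned dimensions. Worked numbers (illustrative only, using patch_size = 14 as in the default arguments of `slice_image` below):

    // ensure_divide(336, 14) == 336  (already a multiple of 14)
    // ensure_divide(340, 14) == 336  (round(340 / 14) = 24 patches -> 24 * 14)
    // ensure_divide(10, 14)  == 14   (clamped up to one full patch)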
@@ -268,7 +426,7 @@ struct minicpmv_preprocessor {
         return std::make_pair(best_width, best_height);
     }
 
-    std::pair<int, int> uhd_get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
+    std::pair<int, int> get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
         int width, height;
         std::tie(width, height) = original_size;
         int grid_x, grid_y;
@@ -281,7 +439,7 @@ struct minicpmv_preprocessor {
         int grid_height = refine_height / grid_y;
 
         // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line)
-        auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair
+        auto best_grid_size = find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair
         int best_grid_width, best_grid_height;
         std::tie(best_grid_width, best_grid_height) = best_grid_size;
 
@@ -290,7 +448,7 @@ struct minicpmv_preprocessor {
         return refine_size;
     }
 
-    std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
+    std::pair<int, int> find_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
         std::vector<int> candidate_split_grids_nums;
         for (int i : {multiple - 1, multiple, multiple + 1}) {
             if (i == 1 || i > max_slice_nums) {
@@ -322,8 +480,8 @@ struct minicpmv_preprocessor {
         return best_grid;
     }
 
-    std::vector<std::vector<clip_image_u8>> uhd_slice_image(
-            const clip_image_u8 & img,
+    std::vector<std::vector<llama_image_u8>> slice_image(
+            const llama_image_u8 & img,
             const int max_slice_nums = 9,
             const int scale_resolution = 448,
             const int patch_size = 14) {
@@ -334,30 +492,30 @@ struct minicpmv_preprocessor {
         const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
         const int multiple = fmin(ceil(ratio), max_slice_nums);
 
-        std::vector<std::vector<clip_image_u8>> images;
+        std::vector<std::vector<llama_image_u8>> images;
         LLAMA_LOG_DEBUG("%s: multiple %d\n", __func__, multiple);
-        images.push_back(std::vector<clip_image_u8>());
+        images.push_back(std::vector<llama_image_u8>());
 
         if (multiple <= 1) {
-            auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
-            clip_image_u8 source_image;
+            auto best_size = find_best_resize(original_size, scale_resolution, patch_size, true);
+            llama_image_u8 source_image;
             bicubic_resize(img, source_image, best_size.first, best_size.second);
             // source_image = image.resize(best_size, Image.Resampling.BICUBIC)
             images[images.size()-1].push_back(source_image);
         }
         else if (multiple > 1) {
-            auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
-            clip_image_u8 source_image;
+            auto best_size = find_best_resize(original_size, scale_resolution, patch_size);
+            llama_image_u8 source_image;
             bicubic_resize(img, source_image, best_size.first, best_size.second);
             // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
             LLAMA_LOG_DEBUG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img.nx, img.ny, best_size.first, best_size.second);
             images[images.size()-1].push_back(source_image);
 
-            std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
+            std::pair<int, int> best_grid = find_best_grid(max_slice_nums, multiple, log_ratio);
             LLAMA_LOG_DEBUG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img.nx, img.ny, best_grid.first, best_grid.second);
 
-            auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
-            clip_image_u8 refine_image;
+            auto refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
+            llama_image_u8 refine_image;
             bicubic_resize(img, refine_image, refine_size.first, refine_size.second);
 
             LLAMA_LOG_DEBUG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image.nx, refine_image.ny, refine_size.first, refine_size.second);
@@ -368,9 +526,9 @@ struct minicpmv_preprocessor {
             int grid_x = int(width / best_grid.first);
             int grid_y = int(height / best_grid.second);
             for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
-                images.push_back(std::vector<clip_image_u8>());
+                images.push_back(std::vector<llama_image_u8>());
                 for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
-                    clip_image_u8 patch;
+                    llama_image_u8 patch;
                     patch.nx = grid_x;
                     patch.ny = grid_y;
                     patch.buf.resize(3 * patch.nx * patch.ny);
@@ -389,173 +547,32 @@ struct minicpmv_preprocessor {
         }
         return images;
     }
 
+    virtual llama_vision_tokens tokenize(const llama_image_u8 & img) override {
+        auto & params = ctx.model->hparams;
+        GGML_ASSERT(params.arch == LLM_ARCH_VISION_MINICPMV);
+
+        std::vector<std::vector<llama_image_u8>> imgs = slice_image(img);
+
+        llama_vision_tokens output;
+        output.n_px = get_n_patches_x(ctx);
+        output.n_py = get_n_patches_y(ctx);
+        output.px = params.patch_size;
+        output.py = params.patch_size;
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            for (size_t j = 0; j < imgs[i].size(); ++j) {
+                std::vector<float> res;
+                normalize_image_u8_to_f32(imgs[i][j], res, params.image_mean, params.image_std);
+                output.buf.push_back(res);
+            }
+        }
+
+        return output;
+    }
+};
 
-static llama_vision_patches clip_image_preprocess_minicpmv(const clip_context & ctx, const clip_image_u8 & img) {
-    auto & params = ctx.model->hparams;
-    GGML_ASSERT(params.arch == LLM_ARCH_VISION_MINICPMV);
-
-    static const int max_slice_nums = 9;
-    minicpmv_preprocessor preprocessor;
-    std::vector<std::vector<clip_image_u8>> imgs = preprocessor.uhd_slice_image(img, max_slice_nums);
-
-    llama_vision_patches output_patches;
-    output_patches.n_px = clip_n_patches_x(ctx);
-    output_patches.n_py = clip_n_patches_y(ctx);
-    output_patches.px = params.patch_size;
-    output_patches.py = params.patch_size;
-
-    for (size_t i = 0; i < imgs.size(); ++i) {
-        for (size_t j = 0; j < imgs[i].size(); ++j) {
-            std::vector<float> res;
-            normalize_image_u8_to_f32(imgs[i][j], res, params.image_mean, params.image_std);
-            output_patches.buf.push_back(res);
-        }
-    }
-}
-
-// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
-// res_imgs memory is being allocated here, previous allocations will be freed if found
-static llama_vision_patches clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img) {
-    bool pad_to_square = true;
-    auto & params = ctx.model->hparams;
-    // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
-    if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) {
-        pad_to_square = false;
-    }
-
-    llama_vision_patches output_patches;
-    output_patches.n_px = clip_n_patches_x(ctx);
-    output_patches.n_py = clip_n_patches_y(ctx);
-    output_patches.px = params.patch_size;
-    output_patches.py = params.patch_size;
-
-    // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
-    // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
-
-    clip_image_u8 temp;
-    if (pad_to_square && img.nx != img.ny) {
-        // if the image is not square, pad it to a square
-        int longer_side = std::max(img.nx, img.ny);
-        temp.nx = longer_side;
-        temp.ny = longer_side;
-        temp.buf.resize(3 * longer_side * longer_side);
-        const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
-
-        // fill with background color
-        for (size_t i = 0; i < temp.buf.size(); i++) {
-            temp.buf[i] = bc[i % 3];
-        }
-
-        // copy from the input image
-        for (int y = 0; y < img.ny; y++) {
-            for (int x = 0; x < img.nx; x++) {
-                const int i = 3 * (y * img.nx + x);
-                const int j = 3 * (y * temp.nx + x);
-                temp.buf[j]   = img.buf[i];
-                temp.buf[j+1] = img.buf[i+1];
-                temp.buf[j+2] = img.buf[i+2];
-            }
-        }
-    } else if (params.image_grid_pinpoints[0] != 0) {
-        // "spatial_unpad" with "anyres" processing for llava-1.6
-        std::vector<clip_image_size> possible_resolutions;
-        for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) {
-            clip_image_size s;
-            s.width  = params.image_grid_pinpoints[i];
-            s.height = params.image_grid_pinpoints[i+1];
-            possible_resolutions.push_back(s);
-        }
-        clip_image_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions);
-        // clip_image_save_to_bmp(*img, "input.bmp");
-        temp = resize_and_pad_image(img, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6
-        // clip_image_save_to_bmp(*temp, "resized.bmp");
-
-        std::vector<clip_image_u8> patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
-
-        clip_image_u8 image_original_resize;
-        // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
-        bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
-        patches.insert(patches.begin(), image_original_resize);
-        // clip_image_f32_batch_init(patches.size());
-        output_patches.buf.resize(patches.size());
-        int num = 0;
-        for (auto & patch : patches) {
-            normalize_image_u8_to_f32(patch, output_patches.buf[num], params.image_mean, params.image_std);
-            num++;
-        }
-        return output_patches;
-    } else {
-        temp.nx = img.nx;
-        temp.ny = img.ny;
-        temp.buf.resize(img.buf.size());
-        memcpy(temp.buf.data(), img.buf.data(), temp.buf.size());
-    }
-
-    const int nx = temp.nx;
-    const int ny = temp.ny;
-    // bmp_export(temp, "resized_vanilla.bmp");
-
-    const int nx2 = params.image_size;
-    const int ny2 = params.image_size;
-    std::vector<float> res;
-    res.resize(3 * nx2 * ny2);
-
-    const float scale = std::max(nx, ny) / (float)params.image_size;
-
-    const int nx3 = int(nx / scale + 0.5f);
-    const int ny3 = int(ny / scale + 0.5f);
-
-    const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
-    const auto & s3 = params.image_std;  // {0.26862954f, 0.26130258f, 0.27577711f};
-
-    for (int y = 0; y < ny3; y++) {
-        for (int x = 0; x < nx3; x++) {
-            for (int c = 0; c < 3; c++) {
-                // linear interpolation
-                const float sx = (x + 0.5f) * scale - 0.5f;
-                const float sy = (y + 0.5f) * scale - 0.5f;
-
-                const int x0 = std::max(0, (int)std::floor(sx));
-                const int y0 = std::max(0, (int)std::floor(sy));
-
-                const int x1 = std::min(x0 + 1, nx - 1);
-                const int y1 = std::min(y0 + 1, ny - 1);
-
-                const float dx = sx - x0;
-                const float dy = sy - y0;
-
-                const int j00 = 3 * (y0 * nx + x0) + c;
-                const int j01 = 3 * (y0 * nx + x1) + c;
-                const int j10 = 3 * (y1 * nx + x0) + c;
-                const int j11 = 3 * (y1 * nx + x1) + c;
-
-                const float v00 = temp.buf[j00];
-                const float v01 = temp.buf[j01];
-                const float v10 = temp.buf[j10];
-                const float v11 = temp.buf[j11];
-
-                const float v0 = v00 * (1.0f - dx) + v01 * dx;
-                const float v1 = v10 * (1.0f - dx) + v11 * dx;
-
-                const float v = v0 * (1.0f - dy) + v1 * dy;
-
-                const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
-
-                const int i = 3 * (y * nx3 + x) + c;
-
-                res[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
-            }
-        }
-    }
-
-    output_patches.buf.resize(1);
-    output_patches.buf[0] = std::move(res);
-
-    return output_patches;
-}
-
-static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, clip_image_size & image_size) {
+static ggml_cgraph * llama_vision_build_graph(llama_vision_context & ctx, int batch_size, img_size & image_size) {
     auto & model = *ctx.model;
     auto & hparams = ctx.model->hparams;
 
@@ -726,7 +743,7 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
     // ne is whcn, ne = [1024, 576, 1, 1]
     embeddings = ggml_get_rows(ctx0, embeddings, patches);
 
-    if (hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
+    if (hparams.proj_type == VISION_PROJECTOR_TYPE_MLP) {
         embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
 
@@ -734,7 +751,7 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
         embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
 
-    } else if (hparams.proj_type == CLIP_PROJECTOR_TYPE_LDPV2) {
+    } else if (hparams.proj_type == VISION_PROJECTOR_TYPE_LDPV2) {
         int n_patch = 24;
         struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
         mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
@@ -770,8 +787,8 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
     return gf;
 }
 
-static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches & patches) {
-    int batch_size = patches.buf.size();
+static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) {
+    int batch_size = inp.buf.size();
     auto & model = *ctx.model;
     auto & hparams = ctx.model->hparams;
 
@@ -779,7 +796,7 @@ static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches
         GGML_ASSERT(batch_size == 1); // TODO: support multiple images
     }
 
-    clip_image_size image_size{(int)hparams.image_size, (int)hparams.image_size};
+    img_size image_size{(int)hparams.image_size, (int)hparams.image_size};
     const int patch_size = hparams.patch_size;
     const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size));
     const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
@@ -788,7 +805,7 @@ static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches
     LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions);
 
     // build the inference graph
-    ggml_cgraph * gf = clip_image_build_graph(ctx, batch_size, image_size);
+    ggml_cgraph * gf = llama_vision_build_graph(ctx, batch_size, image_size);
 
     // alloc memory for graph
     bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf);
@@ -803,15 +820,15 @@ static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches
     float * data = (float *)malloc(ggml_nbytes(inp_raw));
 
     for (int i = 0; i < batch_size; i++) {
-        const int nx = patches.px * patches.n_px;
-        const int ny = patches.py * patches.n_py;
+        const int nx = inp.px * inp.n_px;
+        const int ny = inp.py * inp.n_py;
         const int n = nx * ny;
 
         for (int b = 0; b < batch_size; b++) {
             for (int k = 0; k < 3; k++) {
                 for (int y = 0; y < ny; y++) {
                     for (int x = 0; x < nx; x++) {
-                        data[(b * 3 * n) + k * n + y * nx + x] = patches.buf[b][3 * (y * nx + x) + k];
+                        data[(b * 3 * n) + k * n + y * nx + x] = inp.buf[b][3 * (y * nx + x) + k];
                     }
                 }
            }
@@ -891,34 +908,38 @@ void llama_vision_bitmap_free(llama_vision_bitmap * bmp) {
     delete bmp;
 }
 
-struct llama_vision_patches * llama_vision_patches_init(
+struct llama_vision_tokens * llama_vision_tokenize(
         struct llama_context * ctx,
         llama_vision_bitmap * bmp) {
-    clip_context & vctx = ctx->vctx;
-    if (vctx.model->hparams.arch == LLM_ARCH_VISION_MINICPMV) {
-        return new llama_vision_patches(clip_image_preprocess_minicpmv(vctx, *bmp));
+    llama_vision_context & vctx = ctx->vctx;
+    switch (vctx.model->hparams.arch) {
+        case LLM_ARCH_VISION_LLAVA:
+        case LLM_ARCH_VISION_MOBILEVLM:
+            return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
+        case LLM_ARCH_VISION_MINICPMV:
+            return new llama_vision_tokens(llama_vision_processor_uhd(vctx).tokenize(*bmp));
+        default:
+            GGML_ASSERT(false && "unsupported arch");
     }
-    return new llama_vision_patches(clip_image_preprocess(vctx, *bmp));
 }
 
-void llama_vision_patches_free(llama_vision_patches * p) {
+void llama_vision_tokens_free(llama_vision_tokens * p) {
     delete p;
 }
 
-int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_patches * p) {
+int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_tokens * p) {
     if (p->buf.empty()) {
         LLAMA_LOG_ERROR("%s: nothing to encode\n", __func__);
         return -1;
     }
 
-    clip_context & vctx = ctx->vctx;
+    llama_vision_context & vctx = ctx->vctx;
     auto & hparams = vctx.model->hparams;
     switch (hparams.mm_patch_merge_type) {
         case MM_PATCH_MERGE_FLAT:
             {
                 // flat / default llava-1.5 type embedding
                 // n_output = clip_n_patches(ctx);
-                int32_t encoded = clip_image_encode(vctx, *p);
+                int32_t encoded = llama_vision_encode_impl(vctx, *p);
                 if (encoded != 0) {
                     LLAMA_LOG_ERROR("Unable to encode image\n");
                     return encoded;
@@ -944,7 +965,7 @@ struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx) {
 // for debugging
 #ifndef NDEBUG
 
-static int bmp_export(const struct clip_image_u8 &img, const std::string &location) {
+static int bmp_export(const struct llama_image_u8 &img, const std::string &location) {
     const uint32_t width = img.nx;
     const uint32_t height = img.ny;
     // swap red and blue channel
@@ -7,12 +7,12 @@
 #include <vector>
 #include <array>
 
-enum clip_projector_type {
-    CLIP_PROJECTOR_TYPE_UNKNOWN,
-    CLIP_PROJECTOR_TYPE_MLP,
-    CLIP_PROJECTOR_TYPE_LDPV2,
-    CLIP_PROJECTOR_TYPE_MINICPMV_2_5,
-    CLIP_PROJECTOR_TYPE_MINICPMV_2_6,
+enum vision_projector_type {
+    VISION_PROJECTOR_TYPE_UNKNOWN,
+    VISION_PROJECTOR_TYPE_MLP,
+    VISION_PROJECTOR_TYPE_LDPV2,
+    VISION_PROJECTOR_TYPE_MINICPMV_2_5,
+    VISION_PROJECTOR_TYPE_MINICPMV_2_6,
 };
 
 enum mm_patch_merge {
@@ -21,62 +21,33 @@ enum mm_patch_merge {
     MM_PATCH_MERGE_SPATIAL_UNPAD,
 };
 
-struct clip_hparams {
-    llm_arch arch = LLM_ARCH_UNKNOWN;
+struct llama_vision_model {
+    struct vision_hparams {
+        llm_arch arch = LLM_ARCH_UNKNOWN;
 
-    uint32_t image_size;
-    uint32_t patch_size;
-    uint32_t hidden_size;
-    uint32_t n_intermediate;
-    uint32_t projection_dim;
-    uint32_t n_head;
-    uint32_t n_layer;
-    uint32_t max_pos_embd;
-    int32_t select_layer = 0;
-    bool use_gelu = false;
+        uint32_t image_size;
+        uint32_t patch_size;
+        uint32_t hidden_size;
+        uint32_t n_intermediate;
+        uint32_t projection_dim;
+        uint32_t n_head;
+        uint32_t n_layer;
+        uint32_t max_pos_embd;
+        int32_t select_layer = 0;
+        bool use_gelu = false;
 
-    float eps;
+        float eps;
 
-    clip_projector_type proj_type = CLIP_PROJECTOR_TYPE_UNKNOWN;
-    mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_UNKNOWN;
+        vision_projector_type proj_type = VISION_PROJECTOR_TYPE_UNKNOWN;
+        mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_UNKNOWN;
 
-    std::array<float, 3> image_mean;
-    std::array<float, 3> image_std;
+        std::array<float, 3> image_mean;
+        std::array<float, 3> image_std;
 
-    std::array<int32_t, 32> image_grid_pinpoints; // TODO: should this be array of (x, y) pairs?
-    int32_t image_crop_resolution;
-};
-
-struct clip_layer {
-    // attention
-    struct ggml_tensor * k_w = nullptr;
-    struct ggml_tensor * k_b = nullptr;
-    struct ggml_tensor * q_w = nullptr;
-    struct ggml_tensor * q_b = nullptr;
-    struct ggml_tensor * v_w = nullptr;
-    struct ggml_tensor * v_b = nullptr;
-
-    struct ggml_tensor * output_w = nullptr;
-    struct ggml_tensor * output_b = nullptr;
-
-    // layernorm 1
-    struct ggml_tensor * norm_in_w = nullptr;
-    struct ggml_tensor * norm_in_b = nullptr;
-
-    // ff
-    struct ggml_tensor * ffn_up_w = nullptr;
-    struct ggml_tensor * ffn_up_b = nullptr;
-
-    struct ggml_tensor * ffn_down_w = nullptr;
-    struct ggml_tensor * ffn_down_b = nullptr;
-
-    // layernorm 2
-    struct ggml_tensor * norm_out_w = nullptr;
-    struct ggml_tensor * norm_out_b = nullptr;
-};
-
-struct clip_vision_model {
-    struct clip_hparams hparams;
+        std::array<int32_t, 32> image_grid_pinpoints; // TODO: should this be array of (x, y) pairs?
+        int32_t image_crop_resolution;
+    };
+    struct vision_hparams hparams;
     ggml_backend_buffer_type_t buft;
 
     // embeddings
@@ -88,7 +59,34 @@ struct clip_vision_model {
     struct ggml_tensor * pre_norm_w = nullptr;
     struct ggml_tensor * pre_norm_b = nullptr;
 
-    std::vector<clip_layer> layers;
+    struct vision_layer {
+        // attention
+        struct ggml_tensor * k_w = nullptr;
+        struct ggml_tensor * k_b = nullptr;
+        struct ggml_tensor * q_w = nullptr;
+        struct ggml_tensor * q_b = nullptr;
+        struct ggml_tensor * v_w = nullptr;
+        struct ggml_tensor * v_b = nullptr;
+
+        struct ggml_tensor * output_w = nullptr;
+        struct ggml_tensor * output_b = nullptr;
+
+        // layernorm 1
+        struct ggml_tensor * norm_in_w = nullptr;
+        struct ggml_tensor * norm_in_b = nullptr;
+
+        // ff
+        struct ggml_tensor * ffn_up_w = nullptr;
+        struct ggml_tensor * ffn_up_b = nullptr;
+
+        struct ggml_tensor * ffn_down_w = nullptr;
+        struct ggml_tensor * ffn_down_b = nullptr;
+
+        // layernorm 2
+        struct ggml_tensor * norm_out_w = nullptr;
+        struct ggml_tensor * norm_out_b = nullptr;
+    };
+    std::vector<vision_layer> layers;
 
     struct ggml_tensor * post_norm_w = nullptr;
     struct ggml_tensor * post_norm_b = nullptr;
@@ -132,13 +130,13 @@ struct clip_vision_model {
     struct ggml_tensor * image_newline = nullptr;
 };
 
-struct clip_context {
+struct llama_vision_context {
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
    ggml_backend_sched_t sched = nullptr;
    struct ggml_context * ctx_ggml = nullptr;
 
-    const clip_vision_model * model;
+    const llama_vision_model * model;
 
     // temporary output data, to be picked up by llama_decode()
     struct ggml_tensor * output;
@@ -147,7 +145,7 @@ struct clip_context {
 // for now, this only contains:
 // - the instruction for ggml_conv_2d to break the image into patches
 // - the pre-processed image data in f32
-struct llama_vision_patches {
+struct llama_vision_tokens {
     uint32_t px; // size of patch
     uint32_t py; // size of patch
     size_t n_px; // number of patches in x direction
@@ -166,20 +164,20 @@ inline mm_patch_merge mm_patch_merge_from_name(std::string & name) {
     return MM_PATCH_MERGE_UNKNOWN;
 }
 
-inline clip_projector_type clip_projector_type_from_name(std::string & name) {
+inline vision_projector_type vision_projector_type_from_name(std::string & name) {
     if (name == "mlp") {
-        return CLIP_PROJECTOR_TYPE_MLP;
+        return VISION_PROJECTOR_TYPE_MLP;
     } else if (name == "ldpv2") {
-        return CLIP_PROJECTOR_TYPE_LDPV2;
+        return VISION_PROJECTOR_TYPE_LDPV2;
     } else if (name == "minicpmv-2.5") {
-        return CLIP_PROJECTOR_TYPE_MINICPMV_2_5;
+        return VISION_PROJECTOR_TYPE_MINICPMV_2_5;
     } else if (name == "minicpmv-2.6") {
-        return CLIP_PROJECTOR_TYPE_MINICPMV_2_6;
+        return VISION_PROJECTOR_TYPE_MINICPMV_2_6;
     }
-    return CLIP_PROJECTOR_TYPE_UNKNOWN;
+    return VISION_PROJECTOR_TYPE_UNKNOWN;
 }
 
 // only for sanity check: must be equal to n_embd of language model
-uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model);
+uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel);
 
 struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx);
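The sanity-check declaration above pairs with the `load_tensors` hunk earlier in this commit, where the projector's output width is compared against the language model's embedding width. A minimal sketch of the check as presumably intended (note the hunk shown earlier constructs the `std::runtime_error` without throwing it; the `throw` is added here for illustration):

    // after loading the vision model, its projector output width must
    // match the embedding width of the language model
    if (llama_vision_n_mmproj_embd(clip) != hparams.n_embd) {
        throw std::runtime_error("model has vision, but n_mmproj_embd != n_embd");
    }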