diff --git a/examples/vision/vision.cpp b/examples/vision/vision.cpp index 73f8ef1b6..88b5be5bb 100644 --- a/examples/vision/vision.cpp +++ b/examples/vision/vision.cpp @@ -122,7 +122,7 @@ int main(int argc, char ** argv) { int n_prompt = 0; // process image - llama_vision_patches * img_patches = nullptr; + llama_vision_tokens * img_tokens = nullptr; { const char * img_path = params.image[0].c_str(); if (params.image[0].empty()) { @@ -131,12 +131,12 @@ int main(int argc, char ** argv) { } llama_vision_bitmap * img = load_image_from_file(img_path); LOG_INF("loaded image %s, size = %d x %d\n", img_path, img->nx, img->ny); - img_patches = llama_vision_patches_init(ctx, img); - if (!img_patches) { - LOG_ERR("failed to create image patches\n"); + img_tokens = llama_vision_tokenize(ctx, img); + if (!img_tokens) { + LOG_ERR("failed to create image tokens\n"); return 1; } - if (llama_vision_encode(ctx, img_patches)) { + if (llama_vision_encode(ctx, img_tokens)) { LOG_ERR("failed to encode image\n"); return 1; } diff --git a/include/llama.h b/include/llama.h index bd8e69658..c230e0c3d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -229,7 +229,9 @@ extern "C" { bool sorted; } llama_token_data_array; - struct llama_vision_patches; + // Structure represents the basic input unit of vision model + // This can be a processed image or slices of images under the hood + struct llama_vision_tokens; // represent an RGB image // size of data must be equal to 3*nx*ny @@ -1286,12 +1288,15 @@ extern "C" { LLAMA_API struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny); LLAMA_API void llama_vision_bitmap_free(struct llama_vision_bitmap * bmp); - // Create patches from the RGB bitmap - LLAMA_API struct llama_vision_patches * llama_vision_patches_init(struct llama_context * ctx, llama_vision_bitmap * bmp); - LLAMA_API void llama_vision_patches_free(struct llama_vision_patches * p); + // Create image tokens from the RGB bitmap + LLAMA_API struct llama_vision_tokens * llama_vision_tokenize(struct llama_context * ctx, llama_vision_bitmap * bmp); + LLAMA_API void llama_vision_tokens_free(struct llama_vision_tokens * img_tokens); + + // User must reserve N number of tokens in tokenized text prompt for each image + // LLAMA_API int32_t llama_vision_get_n_tokens(const llama_vision_img_tokens * img_tokens); // Encode patches into embeddings - LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_patches * p); + LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_tokens * img_tokens); LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_context * ctx); // diff --git a/src/llama-context.h b/src/llama-context.h index 10c839f55..f2704b89e 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -110,7 +110,7 @@ struct llama_context { struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] // vision - clip_context vctx; + llama_vision_context vctx; }; // TODO: make these methods of llama_context diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0ea66d254..9e81dafe8 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1268,8 +1268,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { { std::string name; ml.get_key(LLM_KV_VISION_VIT_PROJECTOR_TYPE, name, true); - vparams.proj_type = clip_projector_type_from_name(name); - if (vparams.proj_type == CLIP_PROJECTOR_TYPE_UNKNOWN) { + vparams.proj_type = vision_projector_type_from_name(name); + if (vparams.proj_type == VISION_PROJECTOR_TYPE_UNKNOWN) { throw std::runtime_error(format("unsupported clip projector type: %s", name.c_str())); } } @@ -3514,7 +3514,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { throw std::runtime_error("unknown vision architecture"); } - if (clip_n_mmproj_embd(clip) != hparams.n_embd) { + if (llama_vision_n_mmproj_embd(clip) != hparams.n_embd) { std::runtime_error("model has vision, but n_mmproj_embd != n_embd"); } } diff --git a/src/llama-model.h b/src/llama-model.h index fd3820f1e..d7a17d993 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -365,7 +365,7 @@ struct llama_model { // vision bool has_vision = false; - clip_vision_model clip; + llama_vision_model clip; private: struct impl; diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp index e348d31da..c583593b7 100644 --- a/src/llama-vision.cpp +++ b/src/llama-vision.cpp @@ -13,25 +13,25 @@ #include #include -// export clip_image_u8 to bmp file for debugging +// export llama_image_u8 to bmp file for debugging // https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c -struct clip_image_size; -static int bmp_export(const struct clip_image_u8 &img, const std::string &location); +struct img_size; +static int bmp_export(const struct llama_image_u8 &img, const std::string &location); #endif -struct clip_image_size { +struct img_size { int width; int height; }; // RGB uint8 image // Memory layout: RGBRGBRGB... -struct clip_image_u8 { +struct llama_image_u8 { int nx; int ny; std::vector buf; - clip_image_u8() {} - clip_image_u8(const llama_vision_bitmap & bmp) { + llama_image_u8() {} + llama_image_u8(const llama_vision_bitmap & bmp) { nx = bmp.nx; ny = bmp.ny; buf.resize(nx*ny*3); @@ -39,39 +39,43 @@ struct clip_image_u8 { } }; -struct clip_image_u8_batch { - struct clip_image_u8 * data; - size_t size; -}; - -static int clip_n_patches_x(const clip_context & ctx) { - auto & hparams = ctx.model->hparams; - return hparams.image_size / hparams.patch_size; -} - -static int clip_n_patches_y(const clip_context & ctx) { - return clip_n_patches_x(ctx); -} - -static int clip_n_patches(const clip_context & ctx) { - return clip_n_patches_x(ctx) * clip_n_patches_y(ctx); -} - -uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model) { - auto & proj_type = clip_model.hparams.proj_type; - if (proj_type == CLIP_PROJECTOR_TYPE_MLP) { - return clip_model.mm_2_b->ne[0]; - } else if (proj_type == CLIP_PROJECTOR_TYPE_LDPV2) { - return clip_model.mm_model_peg_0_b->ne[0]; - } else if (proj_type == CLIP_PROJECTOR_TYPE_MINICPMV_2_5) { +uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel) { + auto & proj_type = vmodel.hparams.proj_type; + if (proj_type == VISION_PROJECTOR_TYPE_MLP) { + return vmodel.mm_2_b->ne[0]; + } else if (proj_type == VISION_PROJECTOR_TYPE_LDPV2) { + return vmodel.mm_model_peg_0_b->ne[0]; + } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) { return 4096; - } else if (proj_type == CLIP_PROJECTOR_TYPE_MINICPMV_2_6) { + } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_6) { return 3584; } else { GGML_ASSERT(false && "invalid proj type"); } } + +// +// internal utils +// + +static int get_n_patches_x(const llama_vision_context & ctx) { + auto & hparams = ctx.model->hparams; + return hparams.image_size / hparams.patch_size; +} + +static int get_n_patches_y(const llama_vision_context & ctx) { + return get_n_patches_x(ctx); +} + +static int get_n_patches(const llama_vision_context & ctx) { + return get_n_patches_x(ctx) * get_n_patches_y(ctx); +} + +// +// bitmap utils +// + /** * Selects the best resolution from a list of possible resolutions based on the original size. * @@ -79,11 +83,11 @@ uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model) { * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. * @return The best fit resolution in the format (width, height). */ -static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector& possible_resolutions) { +static img_size select_best_resolution(const img_size & original_size, const std::vector& possible_resolutions) { int original_width = original_size.width; int original_height = original_size.height; - clip_image_size best_fit; + img_size best_fit; int max_effective_resolution = 0; int min_wasted_resolution = std::numeric_limits::max(); @@ -106,7 +110,7 @@ static clip_image_size select_best_resolution(const clip_image_size & original_s return best_fit; } -static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { +static bool bicubic_resize(const llama_image_u8 & img, llama_image_u8 & dst, int target_width, int target_height) { auto clip = [](int x, int lower, int upper) -> int { return std::max(lower, std::min(x, upper)); }; @@ -173,13 +177,13 @@ static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int t return true; } -static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { - std::vector patches; +static std::vector divide_to_patches_u8(const llama_image_u8 & image, int patch_size) { + std::vector patches; int width = image.nx; int height = image.ny; for (int i = 0; i < height; i += patch_size) { for (int j = 0; j < width; j += patch_size) { - clip_image_u8 patch; + llama_image_u8 patch; patch.nx = std::min(patch_size, width - j); patch.ny = std::min(patch_size, height - i); patch.buf.resize(3 * patch.nx * patch.ny); @@ -197,7 +201,7 @@ static std::vector divide_to_patches_u8(const clip_image_u8 & ima } // llava-1.6 type of resize_and_pad (black) -static clip_image_u8 resize_and_pad_image(const clip_image_u8 & image, const clip_image_size & target_resolution) { +static llama_image_u8 resize_and_pad_image(const llama_image_u8 & image, const img_size & target_resolution) { int target_width = target_resolution.width; int target_height = target_resolution.height; @@ -214,11 +218,11 @@ static clip_image_u8 resize_and_pad_image(const clip_image_u8 & image, const cli new_width = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width); } - clip_image_u8 resized_image; + llama_image_u8 resized_image; // bilinear_resize(image, resized_image, new_width, new_height); bicubic_resize(image, resized_image, new_width, new_height); - clip_image_u8 padded_image; + llama_image_u8 padded_image; padded_image.nx = target_width; padded_image.ny = target_height; padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black @@ -238,7 +242,7 @@ static clip_image_u8 resize_and_pad_image(const clip_image_u8 & image, const cli return padded_image; } -static void normalize_image_u8_to_f32(const clip_image_u8 & src, std::vector & dst, const std::array & mean, const std::array & std) { +static void normalize_image_u8_to_f32(const llama_image_u8 & src, std::vector & dst, const std::array & mean, const std::array & std) { dst.resize(src.buf.size()); for (size_t i = 0; i < src.buf.size(); ++i) { @@ -247,15 +251,169 @@ static void normalize_image_u8_to_f32(const clip_image_u8 & src, std::vectorhparams; + // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing + if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) { + pad_to_square = false; + } + + llama_vision_tokens output_slices; + output_slices.n_px = get_n_patches_x(ctx); + output_slices.n_py = get_n_patches_y(ctx); + output_slices.px = params.patch_size; + output_slices.py = params.patch_size; + + // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) + // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 + + llama_image_u8 temp; + if (pad_to_square && img.nx != img.ny) { + // if the image is not square, pad it to a square + int longer_side = std::max(img.nx, img.ny); + temp.nx = longer_side; + temp.ny = longer_side; + temp.buf.resize(3 * longer_side * longer_side); + const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) + + // fill with background color + for (size_t i = 0; i < temp.buf.size(); i++) { + temp.buf[i] = bc[i % 3]; + } + + // copy from the input image + for (int y = 0; y < img.ny; y++) { + for (int x = 0; x < img.nx; x++) { + const int i = 3 * (y * img.nx + x); + const int j = 3 * (y * temp.nx + x); + temp.buf[j] = img.buf[i]; + temp.buf[j+1] = img.buf[i+1]; + temp.buf[j+2] = img.buf[i+2]; + } + } + } else if (params.image_grid_pinpoints[0] != 0) { + // "spatial_unpad" with "anyres" processing for llava-1.6 + std::vector possible_resolutions; + for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) { + img_size s; + s.width = params.image_grid_pinpoints[i]; + s.height = params.image_grid_pinpoints[i+1]; + possible_resolutions.push_back(s); + } + img_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions); + // debug_image_save_to_bmp(*img, "input.bmp"); + temp = resize_and_pad_image(img, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 + // debug_image_save_to_bmp(*temp, "resized.bmp"); + + std::vector patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) + + llama_image_u8 image_original_resize; + // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + patches.insert(patches.begin(), image_original_resize); + output_slices.buf.resize(patches.size()); + int num = 0; + for (auto & patch : patches) { + normalize_image_u8_to_f32(patch, output_slices.buf[num], params.image_mean, params.image_std); + num++; + } + return output_slices; + } else { + temp.nx = img.nx; + temp.ny = img.ny; + temp.buf.resize(img.buf.size()); + memcpy(temp.buf.data(), img.buf.data(), temp.buf.size()); + } + + const int nx = temp.nx; + const int ny = temp.ny; + // bmp_export(temp, "resized_vanilla.bmp"); + + const int nx2 = params.image_size; + const int ny2 = params.image_size; + std::vector res; + res.resize(3 * nx2 * ny2); + + const float scale = std::max(nx, ny) / (float)params.image_size; + + const int nx3 = int(nx / scale + 0.5f); + const int ny3 = int(ny / scale + 0.5f); + + const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; + const auto & s3 = params.image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; + + for (int y = 0; y < ny3; y++) { + for (int x = 0; x < nx3; x++) { + for (int c = 0; c < 3; c++) { + // linear interpolation + const float sx = (x + 0.5f) * scale - 0.5f; + const float sy = (y + 0.5f) * scale - 0.5f; + + const int x0 = std::max(0, (int)std::floor(sx)); + const int y0 = std::max(0, (int)std::floor(sy)); + + const int x1 = std::min(x0 + 1, nx - 1); + const int y1 = std::min(y0 + 1, ny - 1); + + const float dx = sx - x0; + const float dy = sy - y0; + + const int j00 = 3 * (y0 * nx + x0) + c; + const int j01 = 3 * (y0 * nx + x1) + c; + const int j10 = 3 * (y1 * nx + x0) + c; + const int j11 = 3 * (y1 * nx + x1) + c; + + const float v00 = temp.buf[j00]; + const float v01 = temp.buf[j01]; + const float v10 = temp.buf[j10]; + const float v11 = temp.buf[j11]; + + const float v0 = v00 * (1.0f - dx) + v01 * dx; + const float v1 = v10 * (1.0f - dx) + v11 * dx; + + const float v = v0 * (1.0f - dy) + v1 * dy; + + const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); + + const int i = 3 * (y * nx3 + x) + c; + + res[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c]; + } + } + } + + output_slices.buf.resize(1); + output_slices.buf[0] = std::move(res); + + return output_slices; + }; +}; + +struct llama_vision_processor_uhd : llama_vision_processor { + llama_vision_processor_uhd(const llama_vision_context & ctx) : llama_vision_processor(ctx) {} + int ensure_divide(int length, int patch_size) { return std::max(static_cast(std::round(static_cast(length) / patch_size) * patch_size), patch_size); } - std::pair uhd_find_best_resize(std::pair original_size, int scale_resolution, int patch_size, bool allow_upscale = false) { + std::pair find_best_resize(std::pair original_size, int scale_resolution, int patch_size, bool allow_upscale = false) { int width = original_size.first; int height = original_size.second; if ((width * height > scale_resolution * scale_resolution) || allow_upscale) { @@ -268,7 +426,7 @@ struct minicpmv_preprocessor { return std::make_pair(best_width, best_height); } - std::pair uhd_get_refine_size(std::pair original_size, std::pair grid, int scale_resolution, int patch_size, bool allow_upscale = false) { + std::pair get_refine_size(std::pair original_size, std::pair grid, int scale_resolution, int patch_size, bool allow_upscale = false) { int width, height; std::tie(width, height) = original_size; int grid_x, grid_y; @@ -281,7 +439,7 @@ struct minicpmv_preprocessor { int grid_height = refine_height / grid_y; // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line) - auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair + auto best_grid_size = find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair int best_grid_width, best_grid_height; std::tie(best_grid_width, best_grid_height) = best_grid_size; @@ -290,7 +448,7 @@ struct minicpmv_preprocessor { return refine_size; } - std::pair uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { + std::pair find_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { std::vector candidate_split_grids_nums; for (int i : {multiple - 1, multiple, multiple + 1}) { if (i == 1 || i > max_slice_nums) { @@ -322,8 +480,8 @@ struct minicpmv_preprocessor { return best_grid; } - std::vector> uhd_slice_image( - const clip_image_u8 & img, + std::vector> slice_image( + const llama_image_u8 & img, const int max_slice_nums = 9, const int scale_resolution = 448, const int patch_size = 14) { @@ -334,30 +492,30 @@ struct minicpmv_preprocessor { const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); const int multiple = fmin(ceil(ratio), max_slice_nums); - std::vector> images; + std::vector> images; LLAMA_LOG_DEBUG("%s: multiple %d\n", __func__, multiple); - images.push_back(std::vector()); + images.push_back(std::vector()); if (multiple <= 1) { - auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true); - clip_image_u8 source_image; + auto best_size = find_best_resize(original_size, scale_resolution, patch_size, true); + llama_image_u8 source_image; bicubic_resize(img, source_image, best_size.first, best_size.second); // source_image = image.resize(best_size, Image.Resampling.BICUBIC) images[images.size()-1].push_back(source_image); } else if (multiple > 1) { - auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size); - clip_image_u8 source_image; + auto best_size = find_best_resize(original_size, scale_resolution, patch_size); + llama_image_u8 source_image; bicubic_resize(img, source_image, best_size.first, best_size.second); // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) LLAMA_LOG_DEBUG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img.nx, img.ny, best_size.first, best_size.second); images[images.size()-1].push_back(source_image); - std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); + std::pair best_grid = find_best_grid(max_slice_nums, multiple, log_ratio); LLAMA_LOG_DEBUG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img.nx, img.ny, best_grid.first, best_grid.second); - auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); - clip_image_u8 refine_image; + auto refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); + llama_image_u8 refine_image; bicubic_resize(img, refine_image, refine_size.first, refine_size.second); LLAMA_LOG_DEBUG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image.nx, refine_image.ny, refine_size.first, refine_size.second); @@ -368,9 +526,9 @@ struct minicpmv_preprocessor { int grid_x = int(width / best_grid.first); int grid_y = int(height / best_grid.second); for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){ - images.push_back(std::vector()); + images.push_back(std::vector()); for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){ - clip_image_u8 patch; + llama_image_u8 patch; patch.nx = grid_x; patch.ny = grid_y; patch.buf.resize(3 * patch.nx * patch.ny); @@ -389,173 +547,32 @@ struct minicpmv_preprocessor { } return images; } + + virtual llama_vision_tokens tokenize(const llama_image_u8 & img) override { + auto & params = ctx.model->hparams; + GGML_ASSERT(params.arch == LLM_ARCH_VISION_MINICPMV); + + std::vector> imgs = slice_image(img); + + llama_vision_tokens output; + output.n_px = get_n_patches_x(ctx); + output.n_py = get_n_patches_y(ctx); + output.px = params.patch_size; + output.py = params.patch_size; + + for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t j = 0; j < imgs[i].size(); ++j) { + std::vector res; + normalize_image_u8_to_f32(imgs[i][j], res, params.image_mean, params.image_std); + output.buf.push_back(res); + } + } + + return output; + } }; -static llama_vision_patches clip_image_preprocess_minicpmv(const clip_context & ctx, const clip_image_u8 & img) { - auto & params = ctx.model->hparams; - GGML_ASSERT(params.arch == LLM_ARCH_VISION_MINICPMV); - - static const int max_slice_nums = 9; - minicpmv_preprocessor preprocessor; - std::vector> imgs = preprocessor.uhd_slice_image(img, max_slice_nums); - - llama_vision_patches output_patches; - output_patches.n_px = clip_n_patches_x(ctx); - output_patches.n_py = clip_n_patches_y(ctx); - output_patches.px = params.patch_size; - output_patches.py = params.patch_size; - - for (size_t i = 0; i < imgs.size(); ++i) { - for (size_t j = 0; j < imgs[i].size(); ++j) { - std::vector res; - normalize_image_u8_to_f32(imgs[i][j], res, params.image_mean, params.image_std); - output_patches.buf.push_back(res); - } - } -} - -// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector -// res_imgs memory is being allocated here, previous allocations will be freed if found -static llama_vision_patches clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img) { - bool pad_to_square = true; - auto & params = ctx.model->hparams; - // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing - if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) { - pad_to_square = false; - } - - llama_vision_patches output_patches; - output_patches.n_px = clip_n_patches_x(ctx); - output_patches.n_py = clip_n_patches_y(ctx); - output_patches.px = params.patch_size; - output_patches.py = params.patch_size; - - // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) - // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 - - clip_image_u8 temp; - if (pad_to_square && img.nx != img.ny) { - // if the image is not square, pad it to a square - int longer_side = std::max(img.nx, img.ny); - temp.nx = longer_side; - temp.ny = longer_side; - temp.buf.resize(3 * longer_side * longer_side); - const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) - - // fill with background color - for (size_t i = 0; i < temp.buf.size(); i++) { - temp.buf[i] = bc[i % 3]; - } - - // copy from the input image - for (int y = 0; y < img.ny; y++) { - for (int x = 0; x < img.nx; x++) { - const int i = 3 * (y * img.nx + x); - const int j = 3 * (y * temp.nx + x); - temp.buf[j] = img.buf[i]; - temp.buf[j+1] = img.buf[i+1]; - temp.buf[j+2] = img.buf[i+2]; - } - } - } else if (params.image_grid_pinpoints[0] != 0) { - // "spatial_unpad" with "anyres" processing for llava-1.6 - std::vector possible_resolutions; - for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) { - clip_image_size s; - s.width = params.image_grid_pinpoints[i]; - s.height = params.image_grid_pinpoints[i+1]; - possible_resolutions.push_back(s); - } - clip_image_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions); - // clip_image_save_to_bmp(*img, "input.bmp"); - temp = resize_and_pad_image(img, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 - // clip_image_save_to_bmp(*temp, "resized.bmp"); - - std::vector patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) - - clip_image_u8 image_original_resize; - // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - patches.insert(patches.begin(), image_original_resize); - // clip_image_f32_batch_init(patches.size()); - output_patches.buf.resize(patches.size()); - int num = 0; - for (auto & patch : patches) { - normalize_image_u8_to_f32(patch, output_patches.buf[num], params.image_mean, params.image_std); - num++; - } - return output_patches; - } else { - temp.nx = img.nx; - temp.ny = img.ny; - temp.buf.resize(img.buf.size()); - memcpy(temp.buf.data(), img.buf.data(), temp.buf.size()); - } - - const int nx = temp.nx; - const int ny = temp.ny; - // bmp_export(temp, "resized_vanilla.bmp"); - - const int nx2 = params.image_size; - const int ny2 = params.image_size; - std::vector res; - res.resize(3 * nx2 * ny2); - - const float scale = std::max(nx, ny) / (float)params.image_size; - - const int nx3 = int(nx / scale + 0.5f); - const int ny3 = int(ny / scale + 0.5f); - - const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; - const auto & s3 = params.image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; - - for (int y = 0; y < ny3; y++) { - for (int x = 0; x < nx3; x++) { - for (int c = 0; c < 3; c++) { - // linear interpolation - const float sx = (x + 0.5f) * scale - 0.5f; - const float sy = (y + 0.5f) * scale - 0.5f; - - const int x0 = std::max(0, (int)std::floor(sx)); - const int y0 = std::max(0, (int)std::floor(sy)); - - const int x1 = std::min(x0 + 1, nx - 1); - const int y1 = std::min(y0 + 1, ny - 1); - - const float dx = sx - x0; - const float dy = sy - y0; - - const int j00 = 3 * (y0 * nx + x0) + c; - const int j01 = 3 * (y0 * nx + x1) + c; - const int j10 = 3 * (y1 * nx + x0) + c; - const int j11 = 3 * (y1 * nx + x1) + c; - - const float v00 = temp.buf[j00]; - const float v01 = temp.buf[j01]; - const float v10 = temp.buf[j10]; - const float v11 = temp.buf[j11]; - - const float v0 = v00 * (1.0f - dx) + v01 * dx; - const float v1 = v10 * (1.0f - dx) + v11 * dx; - - const float v = v0 * (1.0f - dy) + v1 * dy; - - const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); - - const int i = 3 * (y * nx3 + x) + c; - - res[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c]; - } - } - } - - output_patches.buf.resize(1); - output_patches.buf[0] = std::move(res); - - return output_patches; -} - -static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, clip_image_size & image_size) { +static ggml_cgraph * llama_vision_build_graph(llama_vision_context & ctx, int batch_size, img_size & image_size) { auto & model = *ctx.model; auto & hparams = ctx.model->hparams; @@ -726,7 +743,7 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, // ne is whcn, ne = [1024, 576, 1, 1] embeddings = ggml_get_rows(ctx0, embeddings, patches); - if (hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) { + if (hparams.proj_type == VISION_PROJECTOR_TYPE_MLP) { embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); @@ -734,7 +751,7 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); - } else if (hparams.proj_type == CLIP_PROJECTOR_TYPE_LDPV2) { + } else if (hparams.proj_type == VISION_PROJECTOR_TYPE_LDPV2) { int n_patch = 24; struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b); @@ -770,8 +787,8 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, return gf; } -static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches & patches) { - int batch_size = patches.buf.size(); +static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) { + int batch_size = inp.buf.size(); auto & model = *ctx.model; auto & hparams = ctx.model->hparams; @@ -779,7 +796,7 @@ static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches GGML_ASSERT(batch_size == 1); // TODO: support multiple images } - clip_image_size image_size{(int)hparams.image_size, (int)hparams.image_size}; + img_size image_size{(int)hparams.image_size, (int)hparams.image_size}; const int patch_size = hparams.patch_size; const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size)); const int num_positions = num_patches + (model.class_embedding ? 1 : 0); @@ -788,7 +805,7 @@ static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions); // build the inference graph - ggml_cgraph * gf = clip_image_build_graph(ctx, batch_size, image_size); + ggml_cgraph * gf = llama_vision_build_graph(ctx, batch_size, image_size); // alloc memory for graph bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf); @@ -803,15 +820,15 @@ static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches float * data = (float *)malloc(ggml_nbytes(inp_raw)); for (int i = 0; i < batch_size; i++) { - const int nx = patches.px * patches.n_px; - const int ny = patches.py * patches.n_py; + const int nx = inp.px * inp.n_px; + const int ny = inp.py * inp.n_py; const int n = nx * ny; for (int b = 0; b < batch_size; b++) { for (int k = 0; k < 3; k++) { for (int y = 0; y < ny; y++) { for (int x = 0; x < nx; x++) { - data[(b * 3 * n) + k * n + y * nx + x] = patches.buf[b][3 * (y * nx + x) + k]; + data[(b * 3 * n) + k * n + y * nx + x] = inp.buf[b][3 * (y * nx + x) + k]; } } } @@ -891,34 +908,38 @@ void llama_vision_bitmap_free(llama_vision_bitmap * bmp) { delete bmp; } -struct llama_vision_patches * llama_vision_patches_init( +struct llama_vision_tokens * llama_vision_tokenize( struct llama_context * ctx, llama_vision_bitmap * bmp) { - clip_context & vctx = ctx->vctx; - if (vctx.model->hparams.arch == LLM_ARCH_VISION_MINICPMV) { - return new llama_vision_patches(clip_image_preprocess_minicpmv(vctx, *bmp)); + llama_vision_context & vctx = ctx->vctx; + switch (vctx.model->hparams.arch) { + case LLM_ARCH_VISION_LLAVA: + case LLM_ARCH_VISION_MOBILEVLM: + return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp)); + case LLM_ARCH_VISION_MINICPMV: + return new llama_vision_tokens(llama_vision_processor_uhd(vctx).tokenize(*bmp)); + default: + GGML_ASSERT(false && "unsupported arch"); } - return new llama_vision_patches(clip_image_preprocess(vctx, *bmp)); } -void llama_vision_patches_free(llama_vision_patches * p) { +void llama_vision_tokens_free(llama_vision_tokens * p) { delete p; } -int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_patches * p) { +int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_tokens * p) { if (p->buf.empty()) { LLAMA_LOG_ERROR("%s: nothing to encode\n", __func__); return -1; } - clip_context & vctx = ctx->vctx; + llama_vision_context & vctx = ctx->vctx; auto & hparams = vctx.model->hparams; switch (hparams.mm_patch_merge_type) { case MM_PATCH_MERGE_FLAT: { // flat / default llava-1.5 type embedding - // n_output = clip_n_patches(ctx); - int32_t encoded = clip_image_encode(vctx, *p); + int32_t encoded = llama_vision_encode_impl(vctx, *p); if (encoded != 0) { LLAMA_LOG_ERROR("Unable to encode image\n"); return encoded; @@ -944,7 +965,7 @@ struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx) { // for debugging #ifndef NDEBUG -static int bmp_export(const struct clip_image_u8 &img, const std::string &location) { +static int bmp_export(const struct llama_image_u8 &img, const std::string &location) { const uint32_t width = img.nx; const uint32_t height = img.ny; // swap red and blue channel diff --git a/src/llama-vision.h b/src/llama-vision.h index a9304867f..374ae4537 100644 --- a/src/llama-vision.h +++ b/src/llama-vision.h @@ -7,12 +7,12 @@ #include #include -enum clip_projector_type { - CLIP_PROJECTOR_TYPE_UNKNOWN, - CLIP_PROJECTOR_TYPE_MLP, - CLIP_PROJECTOR_TYPE_LDPV2, - CLIP_PROJECTOR_TYPE_MINICPMV_2_5, - CLIP_PROJECTOR_TYPE_MINICPMV_2_6, +enum vision_projector_type { + VISION_PROJECTOR_TYPE_UNKNOWN, + VISION_PROJECTOR_TYPE_MLP, + VISION_PROJECTOR_TYPE_LDPV2, + VISION_PROJECTOR_TYPE_MINICPMV_2_5, + VISION_PROJECTOR_TYPE_MINICPMV_2_6, }; enum mm_patch_merge { @@ -21,62 +21,33 @@ enum mm_patch_merge { MM_PATCH_MERGE_SPATIAL_UNPAD, }; -struct clip_hparams { - llm_arch arch = LLM_ARCH_UNKNOWN; +struct llama_vision_model { + struct vision_hparams { + llm_arch arch = LLM_ARCH_UNKNOWN; - uint32_t image_size; - uint32_t patch_size; - uint32_t hidden_size; - uint32_t n_intermediate; - uint32_t projection_dim; - uint32_t n_head; - uint32_t n_layer; - uint32_t max_pos_embd; - int32_t select_layer = 0; - bool use_gelu = false; + uint32_t image_size; + uint32_t patch_size; + uint32_t hidden_size; + uint32_t n_intermediate; + uint32_t projection_dim; + uint32_t n_head; + uint32_t n_layer; + uint32_t max_pos_embd; + int32_t select_layer = 0; + bool use_gelu = false; - float eps; + float eps; - clip_projector_type proj_type = CLIP_PROJECTOR_TYPE_UNKNOWN; - mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_UNKNOWN; + vision_projector_type proj_type = VISION_PROJECTOR_TYPE_UNKNOWN; + mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_UNKNOWN; - std::array image_mean; - std::array image_std; + std::array image_mean; + std::array image_std; - std::array image_grid_pinpoints; // TODO: should this be array of (x, y) pairs? - int32_t image_crop_resolution; -}; - -struct clip_layer { - // attention - struct ggml_tensor * k_w = nullptr; - struct ggml_tensor * k_b = nullptr; - struct ggml_tensor * q_w = nullptr; - struct ggml_tensor * q_b = nullptr; - struct ggml_tensor * v_w = nullptr; - struct ggml_tensor * v_b = nullptr; - - struct ggml_tensor * output_w = nullptr; - struct ggml_tensor * output_b = nullptr; - - // layernorm 1 - struct ggml_tensor * norm_in_w = nullptr; - struct ggml_tensor * norm_in_b = nullptr; - - // ff - struct ggml_tensor * ffn_up_w = nullptr; - struct ggml_tensor * ffn_up_b = nullptr; - - struct ggml_tensor * ffn_down_w = nullptr; - struct ggml_tensor * ffn_down_b = nullptr; - - // layernorm 2 - struct ggml_tensor * norm_out_w = nullptr; - struct ggml_tensor * norm_out_b = nullptr; -}; - -struct clip_vision_model { - struct clip_hparams hparams; + std::array image_grid_pinpoints; // TODO: should this be array of (x, y) pairs? + int32_t image_crop_resolution; + }; + struct vision_hparams hparams; ggml_backend_buffer_type_t buft; // embeddings @@ -88,7 +59,34 @@ struct clip_vision_model { struct ggml_tensor * pre_norm_w = nullptr; struct ggml_tensor * pre_norm_b = nullptr; - std::vector layers; + struct vision_layer { + // attention + struct ggml_tensor * k_w = nullptr; + struct ggml_tensor * k_b = nullptr; + struct ggml_tensor * q_w = nullptr; + struct ggml_tensor * q_b = nullptr; + struct ggml_tensor * v_w = nullptr; + struct ggml_tensor * v_b = nullptr; + + struct ggml_tensor * output_w = nullptr; + struct ggml_tensor * output_b = nullptr; + + // layernorm 1 + struct ggml_tensor * norm_in_w = nullptr; + struct ggml_tensor * norm_in_b = nullptr; + + // ff + struct ggml_tensor * ffn_up_w = nullptr; + struct ggml_tensor * ffn_up_b = nullptr; + + struct ggml_tensor * ffn_down_w = nullptr; + struct ggml_tensor * ffn_down_b = nullptr; + + // layernorm 2 + struct ggml_tensor * norm_out_w = nullptr; + struct ggml_tensor * norm_out_b = nullptr; + }; + std::vector layers; struct ggml_tensor * post_norm_w = nullptr; struct ggml_tensor * post_norm_b = nullptr; @@ -132,13 +130,13 @@ struct clip_vision_model { struct ggml_tensor * image_newline = nullptr; }; -struct clip_context { +struct llama_vision_context { // memory buffers used to evaluate the model std::vector buf_compute_meta; ggml_backend_sched_t sched = nullptr; struct ggml_context * ctx_ggml = nullptr; - const clip_vision_model * model; + const llama_vision_model * model; // temporary output data, to be picked up by llama_decode() struct ggml_tensor * output; @@ -147,7 +145,7 @@ struct clip_context { // for now, this only contains: // - the instruction for ggml_conv_2d to break the image into patches // - the pre-processed image data in f32 -struct llama_vision_patches { +struct llama_vision_tokens { uint32_t px; // size of patch uint32_t py; // size of patch size_t n_px; // number of patches in x direction @@ -166,20 +164,20 @@ inline mm_patch_merge mm_patch_merge_from_name(std::string & name) { return MM_PATCH_MERGE_UNKNOWN; } -inline clip_projector_type clip_projector_type_from_name(std::string & name) { +inline vision_projector_type vision_projector_type_from_name(std::string & name) { if (name == "mlp") { - return CLIP_PROJECTOR_TYPE_MLP; + return VISION_PROJECTOR_TYPE_MLP; } else if (name == "ldpv2") { - return CLIP_PROJECTOR_TYPE_LDPV2; + return VISION_PROJECTOR_TYPE_LDPV2; } else if (name == "minicpmv-2.5") { - return CLIP_PROJECTOR_TYPE_MINICPMV_2_5; + return VISION_PROJECTOR_TYPE_MINICPMV_2_5; } else if (name == "minicpmv-2.6") { - return CLIP_PROJECTOR_TYPE_MINICPMV_2_6; + return VISION_PROJECTOR_TYPE_MINICPMV_2_6; } - return CLIP_PROJECTOR_TYPE_UNKNOWN; + return VISION_PROJECTOR_TYPE_UNKNOWN; } // only for sanity check: must be equal to n_embd of language model -uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model); +uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel); struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx);