random pos_embed

This commit is contained in:
caitianchi 2024-05-26 19:40:37 +08:00
parent 629420ee39
commit b48708af22
5 changed files with 389 additions and 323 deletions

1
.gitignore vendored
View File

@ -61,6 +61,7 @@ models-mnt
/llama-bench /llama-bench
/llava-cli /llava-cli
/minicpmv-cli /minicpmv-cli
/openbmb
/lookahead /lookahead
/lookup /lookup
/lookup-create /lookup-create

View File

@ -548,43 +548,99 @@ struct clip_ctx {
ggml_gallocr_t compute_alloc = NULL; ggml_gallocr_t compute_alloc = NULL;
}; };
void print_tensor_info(const struct ggml_tensor *tensor, const std::string &name) { std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>>& pos) {
std::cout << "Tensor " << name << ": (" assert(embed_dim % 2 == 0);
<< tensor->ne[0] << ", " << tensor->ne[1] << ", " << tensor->ne[2] << ")" << std::endl; int H = pos.size();
for (int i = 0; i < tensor->ne[0]; ++i) { int W = pos[0].size();
for (int j = 0; j < tensor->ne[1]; ++j) {
for (int k = 0; k < tensor->ne[2]; ++k) { std::vector<float> omega(embed_dim / 2);
std::cout << ((float *)tensor->data)[i * tensor->ne[1] * tensor->ne[2] + j * tensor->ne[2] + k] << " "; for (int i = 0; i < embed_dim / 2; ++i) {
omega[i] = 1.0 / pow(10000.0, static_cast<float>(i) / (embed_dim / 2));
}
std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim)));
for (int h = 0; h < H; ++h) {
for (int w = 0; w < W; ++w) {
for (int d = 0; d < embed_dim / 2; ++d) {
float out_value = pos[h][w] * omega[d];
emb[h][w][d] = sin(out_value);
emb[h][w][d + embed_dim / 2] = cos(out_value);
} }
std::cout << std::endl;
} }
std::cout << std::endl;
}
}
struct ggml_tensor *ggml_sin(struct ggml_context *ctx, struct ggml_tensor *input) {
int size = input->ne[0] * input->ne[1] * input->ne[2];
struct ggml_tensor *out = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, input->ne[0], input->ne[1], input->ne[2]);
for (int i = 0; i < size; ++i) {
((float *)out->data)[i] = std::sin(((float *)input->data)[i]);
} }
return out; return emb;
} }
struct ggml_tensor *ggml_cos(struct ggml_context *ctx, struct ggml_tensor *input) { std::vector<std::vector<std::vector<float>>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector<std::vector<std::vector<float>>>& grid) {
int size = input->ne[0] * input->ne[1] * input->ne[2]; assert(embed_dim % 2 == 0);
struct ggml_tensor *out = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, input->ne[0], input->ne[1], input->ne[2]); std::vector<std::vector<std::vector<float>>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2)
std::vector<std::vector<std::vector<float>>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2)
for (int i = 0; i < size; ++i) { int H = emb_h.size();
((float *)out->data)[i] = std::cos(((float *)input->data)[i]); int W = emb_h[0].size();
std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim)));
for (int h = 0; h < H; ++h) {
for (int w = 0; w < W; ++w) {
for (int d = 0; d < embed_dim / 2; ++d) {
emb[h][w][d] = emb_h[h][w][d];
emb[h][w][d + embed_dim / 2] = emb_w[h][w][d];
}
}
}
return emb;
}
struct ggml_tensor * get_2d_sincos_pos_embed(int embed_dim, const std::pair<int, int> image_size, struct ggml_context * ctx, struct ggml_tensor * pos_embed) {
int grid_h_size = image_size.first;
int grid_w_size = image_size.second;
std::vector<float> grid_h(grid_h_size);
std::vector<float> grid_w(grid_w_size);
for (int i = 0; i < grid_h_size; ++i) {
grid_h[i] = static_cast<float>(i);
}
for (int i = 0; i < grid_w_size; ++i) {
grid_w[i] = static_cast<float>(i);
} }
return out; std::vector<std::vector<float>> grid(grid_h_size, std::vector<float>(grid_w_size));
for (int h = 0; h < grid_h_size; ++h) {
for (int w = 0; w < grid_w_size; ++w) {
grid[h][w] = grid_w[w];
}
}
std::vector<std::vector<std::vector<float>>> grid_2d = {grid, grid};
for (int h = 0; h < grid_h_size; ++h) {
for (int w = 0; w < grid_w_size; ++w) {
grid_2d[0][h][w] = grid_h[h];
grid_2d[1][h][w] = grid_w[w];
}
}
std::vector<std::vector<std::vector<float>>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d);
int H = image_size.first;
int W = image_size.second;
std::vector<std::vector<float>> pos_embed_2d(H * W, std::vector<float>(embed_dim));
for (int h = 0; h < H; ++h) {
for (int w = 0; w < W; ++w) {
pos_embed_2d[w * H + h] = pos_embed_3d[h][w];
}
}
float* dataArray = static_cast<float*>(pos_embed->data);
for(int i=0;i<grid_h_size * grid_w_size;++i){
for(int j=0;j<embed_dim;++j){
dataArray[i*embed_dim+j]=pos_embed_2d[i][j];
}
}
return pos_embed;
} }
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, std::pair<int, int> load_image_size = {70, 70}) { static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, std::pair<int, int> load_image_size = {448, 448}) {
if (!ctx->has_vision_encoder) { if (!ctx->has_vision_encoder) {
LOG_TEE("This gguf file seems to have no vision encoder\n"); LOG_TEE("This gguf file seems to have no vision encoder\n");
return nullptr; return nullptr;
@ -593,10 +649,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
const auto & model = ctx->vision_model; const auto & model = ctx->vision_model;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const int image_size = hparams.image_size; const int image_size_width = load_image_size.first;
const int image_size_height = load_image_size.second;
const int patch_size = hparams.patch_size; const int patch_size = hparams.patch_size;
const int num_patches = ((image_size / patch_size) * (image_size / patch_size)); const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int num_patches_per_side = image_size / patch_size; GGML_UNUSED(num_patches_per_side);
const int num_positions = num_patches; const int num_positions = num_patches;
const int hidden_size = hparams.hidden_size; const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head; const int n_head = hparams.n_head;
@ -619,8 +675,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
LOG_TEE("%s: ctx->buf_compute_meta.size(): %d \n", __func__, ctx->buf_compute_meta.size()); LOG_TEE("%s: ctx->buf_compute_meta.size(): %d \n", __func__, ctx->buf_compute_meta.size());
struct ggml_context * ctx0 = ggml_init(params); struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size); LOG_TEE("%s: load_image_size: %d %d\n", __func__, load_image_size.first, load_image_size.second);
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
ggml_set_name(inp_raw, "inp_raw"); ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw); ggml_set_input(inp_raw);
@ -648,6 +705,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
struct ggml_tensor * embeddings = struct ggml_tensor * embeddings =
ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions)); ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions));
int pos_w = image_size_width/patch_size;
int pos_h = image_size_height/patch_size;
struct ggml_tensor * pos_embed = get_2d_sincos_pos_embed(4096, std::make_pair(pos_w, pos_h), ctx0, model.mm_model_pos_embed_k);
pos_embed = ggml_view_3d(ctx0, pos_embed, 4096, pos_w * pos_h, 1, pos_embed->nb[1], pos_embed->nb[2], 0);
// // pre-layernorm // // pre-layernorm
// { // {
// embeddings = ggml_norm(ctx0, embeddings, eps); // embeddings = ggml_norm(ctx0, embeddings, eps);
@ -933,7 +995,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
} }
{ // position { // position
// q = ggml_add(ctx0, q, model.mm_model_pos_embed); // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
k = ggml_add(ctx0, v, model.mm_model_pos_embed_k); k = ggml_add(ctx0, v, pos_embed);
} }
{ // attention { // attention
@ -985,7 +1047,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
} }
// read and create ggml_context containing the tensors and their data // read and create ggml_context containing the tensors and their data
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1, std::pair<int, int> load_image_size = {70, 70}) { struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1, std::pair<int, int> load_image_size = {448, 448}) {
struct ggml_context * meta = NULL; struct ggml_context * meta = NULL;
struct gguf_init_params params = { struct gguf_init_params params = {
@ -1718,95 +1780,103 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily
temp->nx = img->nx;
temp->ny = img->ny;
temp->buf.resize(img->buf.size());
memcpy(temp->buf.data(), img->buf.data(), temp->buf.size());
if (pad_to_square && img->nx != img->ny) { // if (pad_to_square && img->nx != img->ny) {
int longer_side = std::max(img->nx, img->ny); // int longer_side = std::max(img->nx, img->ny);
temp->nx = longer_side; // temp->nx = img->nx;
temp->ny = longer_side; // temp->ny = longer_side;
temp->buf.resize(3 * longer_side * longer_side); // temp->buf.resize(3 * longer_side * longer_side);
const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) // const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
// fill with background color // // fill with background color
for (size_t i = 0; i < temp->buf.size(); i++) { // for (size_t i = 0; i < temp->buf.size(); i++) {
temp->buf[i] = bc[i % 3]; // temp->buf[i] = bc[i % 3];
} // }
// copy from the input image // // copy from the input image
for (int y = 0; y < img->ny; y++) { // for (int y = 0; y < img->ny; y++) {
for (int x = 0; x < img->nx; x++) { // for (int x = 0; x < img->nx; x++) {
const int i = 3 * (y * img->nx + x); // const int i = 3 * (y * img->nx + x);
const int j = 3 * (y * temp->nx + x); // const int j = 3 * (y * temp->nx + x);
temp->buf[j] = img->buf[i]; // temp->buf[j] = img->buf[i];
temp->buf[j+1] = img->buf[i+1]; // temp->buf[j+1] = img->buf[i+1];
temp->buf[j+2] = img->buf[i+2]; // temp->buf[j+2] = img->buf[i+2];
} // }
} // }
} else { // } else {
if (params.image_grid_pinpoints[0] != 0) { // if (params.image_grid_pinpoints[0] != 0) {
// "spatial_unpad" with "anyres" processing for llava-1.6 // // "spatial_unpad" with "anyres" processing for llava-1.6
std::vector<std::pair<int, int>> possible_resolutions; // std::vector<std::pair<int, int>> possible_resolutions;
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { // for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); // possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
} // }
std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); // std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
// clip_image_save_to_bmp(*img, "input.bmp"); // // clip_image_save_to_bmp(*img, "input.bmp");
resize_and_pad_image(*img, *temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 // resize_and_pad_image(*img, *temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6
// clip_image_save_to_bmp(*temp, "resized.bmp"); // // clip_image_save_to_bmp(*temp, "resized.bmp");
// visually verify normalized image: // // visually verify normalized image:
// normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std); // // normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
// { // // {
// clip_image_u8 * temp2 = clip_image_u8_init(); // // clip_image_u8 * temp2 = clip_image_u8_init();
// clip_image_convert_f32_to_u8(*res, *temp2); // // clip_image_convert_f32_to_u8(*res, *temp2);
// clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp"); // // clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp");
// clip_image_u8_free(temp2); // // clip_image_u8_free(temp2);
// } // // }
std::vector<clip_image_u8 *> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) // std::vector<clip_image_u8 *> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
clip_image_u8 *image_original_resize = clip_image_u8_init(); // clip_image_u8 *image_original_resize = clip_image_u8_init();
// bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square // // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square // bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
patches.insert(patches.begin(), image_original_resize); // patches.insert(patches.begin(), image_original_resize);
// clip_image_f32_batch_init(patches.size()); // // clip_image_f32_batch_init(patches.size());
res_imgs->size = patches.size(); // res_imgs->size = patches.size();
res_imgs->data = new clip_image_f32[res_imgs->size]; // res_imgs->data = new clip_image_f32[res_imgs->size];
int num=0; // int num=0;
for (auto& patch : patches) { // for (auto& patch : patches) {
normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std); // normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std);
num++; // num++;
} // }
for (size_t i = 0; i < patches.size(); i++) { // for (size_t i = 0; i < patches.size(); i++) {
// LOG_TEE("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny); // // LOG_TEE("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny);
clip_image_u8_free(patches[i]); // clip_image_u8_free(patches[i]);
} // }
clip_image_u8_free(temp); // clip_image_u8_free(temp);
return true; // return true;
} else { // } else {
temp->nx = img->nx; // temp->nx = img->nx;
temp->ny = img->ny; // temp->ny = img->ny;
temp->buf.resize(img->buf.size()); // temp->buf.resize(img->buf.size());
memcpy(temp->buf.data(), img->buf.data(), temp->buf.size()); // memcpy(temp->buf.data(), img->buf.data(), temp->buf.size());
} // }
} // }
const int nx = temp->nx; const int nx = temp->nx;
const int ny = temp->ny; const int ny = temp->ny;
// clip_image_save_to_bmp(*temp, "resized_vanilla.bmp"); // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp");
const int nx2 = ctx->vision_model.hparams.image_size; const int nx2 = temp->nx;
const int ny2 = ctx->vision_model.hparams.image_size; const int ny2 = temp->ny;
clip_image_f32 * res = clip_image_f32_init(); clip_image_f32 * res = clip_image_f32_init();
res->nx = nx2; res->nx = nx2;
res->ny = ny2; res->ny = ny2;
res->buf.resize(3 * nx2 * ny2); res->buf.resize(3 * nx2 * ny2);
const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size; // const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size;
const int nx3 = int(nx / scale + 0.5f); // const int nx3 = int(nx / scale + 0.5f);
const int ny3 = int(ny / scale + 0.5f); // const int ny3 = int(ny / scale + 0.5f);
const int nx3 = nx;
const int ny3 = ny;
const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
const auto & s3 = ctx->image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; const auto & s3 = ctx->image_std; // {0.26862954f, 0.26130258f, 0.27577711f};
@ -1815,8 +1885,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
for (int x = 0; x < nx3; x++) { for (int x = 0; x < nx3; x++) {
for (int c = 0; c < 3; c++) { for (int c = 0; c < 3; c++) {
// linear interpolation // linear interpolation
const float sx = (x + 0.5f) * scale - 0.5f; const float sx = x;
const float sy = (y + 0.5f) * scale - 0.5f; const float sy = y;
const int x0 = std::max(0, (int)std::floor(sx)); const int x0 = std::max(0, (int)std::floor(sx));
const int y0 = std::max(0, (int)std::floor(sy)); const int y0 = std::max(0, (int)std::floor(sy));
@ -1920,7 +1990,7 @@ int clip_n_patches(const struct clip_ctx * ctx) {
return n_patches; return n_patches;
} }
bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec, std::pair<int, int> load_image_size = {70, 70}) { bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec, std::pair<int, int> load_image_size = {448, 448}) {
if (!ctx->has_vision_encoder) { if (!ctx->has_vision_encoder) {
LOG_TEE("This gguf file seems to have no vision encoder\n"); LOG_TEE("This gguf file seems to have no vision encoder\n");
return false; return false;
@ -1932,7 +2002,7 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
return clip_image_batch_encode(ctx, n_threads, &imgs, vec, load_image_size); return clip_image_batch_encode(ctx, n_threads, &imgs, vec, load_image_size);
} }
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec, std::pair<int, int> load_image_size = {70, 70}) { bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec, std::pair<int, int> load_image_size = {448, 448}) {
if (!ctx->has_vision_encoder) { if (!ctx->has_vision_encoder) {
LOG_TEE("This gguf file seems to have no vision encoder\n"); LOG_TEE("This gguf file seems to have no vision encoder\n");
return false; return false;
@ -1963,7 +2033,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
for (size_t i = 0; i < imgs->size; i++) { for (size_t i = 0; i < imgs->size; i++) {
const int nx = imgs->data[i].nx; const int nx = imgs->data[i].nx;
const int ny = imgs->data[i].ny; const int ny = imgs->data[i].ny;
GGML_ASSERT(nx == image_size && ny == image_size); // GGML_ASSERT(nx == image_size && ny == image_size);
const int n = nx * ny; const int n = nx * ny;

View File

@ -275,11 +275,11 @@ def _replace_name_resampler(s, v):
if re.match("resampler.pos_embed", s): if re.match("resampler.pos_embed", s):
return { return {
s: v, s: v,
re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (448//14, 448//14))), re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
} }
if re.match("resampler.proj", s): if re.match("resampler.proj", s):
return { return {
re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (448//14, 448//14))), re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(), re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
} }
if re.match("resampler.attn.in_proj_.*", s): if re.match("resampler.attn.in_proj_.*", s):

View File

@ -31,191 +31,191 @@ struct clip_image_grid_shape {
int second; int second;
}; };
/** // /**
* Selects the best resolution from a list of possible resolutions based on the original size. // * Selects the best resolution from a list of possible resolutions based on the original size.
* // *
* @param original_size The original size of the image in the format (width, height). // * @param original_size The original size of the image in the format (width, height).
* @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. // * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
* @return The best fit resolution in the format (width, height). // * @return The best fit resolution in the format (width, height).
*/ // */
static std::pair<int, int> select_best_resolution(const std::pair<int, int>& original_size, const std::vector<std::pair<int, int>>& possible_resolutions) { // static std::pair<int, int> select_best_resolution(const std::pair<int, int>& original_size, const std::vector<std::pair<int, int>>& possible_resolutions) {
int original_width = original_size.first; // int original_width = original_size.first;
int original_height = original_size.second; // int original_height = original_size.second;
std::pair<int, int> best_fit; // std::pair<int, int> best_fit;
int max_effective_resolution = 0; // int max_effective_resolution = 0;
int min_wasted_resolution = std::numeric_limits<int>::max(); // int min_wasted_resolution = std::numeric_limits<int>::max();
for (const auto& resolution : possible_resolutions) { // for (const auto& resolution : possible_resolutions) {
int width = resolution.first; // int width = resolution.first;
int height = resolution.second; // int height = resolution.second;
float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height); // float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
int downscaled_width = static_cast<int>(original_width * scale); // int downscaled_width = static_cast<int>(original_width * scale);
int downscaled_height = static_cast<int>(original_height * scale); // int downscaled_height = static_cast<int>(original_height * scale);
int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); // int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
int wasted_resolution = (width * height) - effective_resolution; // int wasted_resolution = (width * height) - effective_resolution;
// LOG_TEE("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); // // LOG_TEE("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { // if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
max_effective_resolution = effective_resolution; // max_effective_resolution = effective_resolution;
min_wasted_resolution = wasted_resolution; // min_wasted_resolution = wasted_resolution;
best_fit = resolution; // best_fit = resolution;
} // }
} // }
return best_fit; // return best_fit;
} // }
/** // /**
* @brief Get the anyres image grid shape object // * @brief Get the anyres image grid shape object
* // *
* @param image_size // * @param image_size
* @param grid_pinpoints // * @param grid_pinpoints
* @param image_patch_size // * @param image_patch_size
* @return <int, int> // * @return <int, int>
*/ // */
static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<int, int> & image_size, const std::vector<std::pair<int, int>> & grid_pinpoints, int image_patch_size) { // static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<int, int> & image_size, const std::vector<std::pair<int, int>> & grid_pinpoints, int image_patch_size) {
/** // /**
Conversion from gguf flat array to vector: // Conversion from gguf flat array to vector:
std::vector<std::pair<int, int>> possible_resolutions; // std::vector<std::pair<int, int>> possible_resolutions;
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { // for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); // possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
} // }
*/ // */
auto best_resolution = select_best_resolution(image_size, grid_pinpoints); // auto best_resolution = select_best_resolution(image_size, grid_pinpoints);
return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size}; // return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size};
} // }
// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out) // // Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out)
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) { // static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) {
struct { // struct {
struct ggml_tensor * newline; // struct ggml_tensor * newline;
struct ggml_context * ctx; // struct ggml_context * ctx;
} model; // } model;
const int32_t image_size = clip_image_size(ctx_clip); // const int32_t image_size = clip_image_size(ctx_clip);
const int32_t patch_size = clip_patch_size(ctx_clip); // const int32_t patch_size = clip_patch_size(ctx_clip);
int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) // int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches)
int num_patches_width = grid_shape.first; // grid 1-4 // int num_patches_width = grid_shape.first; // grid 1-4
int num_patches_height = grid_shape.second; // grid 1-4 // int num_patches_height = grid_shape.second; // grid 1-4
const size_t num_images = num_patches_width * num_patches_height + 1; // const size_t num_images = num_patches_width * num_patches_height + 1;
// TODO: size calculation is not calculated - it's only tens of MB // // TODO: size calculation is not calculated - it's only tens of MB
size_t ctx_size = 0; // size_t ctx_size = 0;
{ // {
ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features // ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features
ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32); // ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32);
} // }
struct ggml_init_params params { // struct ggml_init_params params {
/*.mem_size =*/ ctx_size, // /*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL, // /*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false, // NOTE: this should be false when using the legacy API // /*.no_alloc =*/ false, // NOTE: this should be false when using the legacy API
}; // };
// Python reference code for full unpad: // // Python reference code for full unpad:
/* // /*
base_image_feature = image_feature[0] // base_image_feature = image_feature[0]
image_feature = image_feature[1:] // image_feature = image_feature[1:]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() // image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3) // image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx]) // image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(( // image_feature = torch.cat((
image_feature, // image_feature,
self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1) // self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1)
), dim=-1) // ), dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1) // image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0) // image_feature = torch.cat((base_image_feature, image_feature), dim=0)
*/ // */
// We now have two options: unpad or no unpad. Unpad removes tokens for faster llm eval. // // We now have two options: unpad or no unpad. Unpad removes tokens for faster llm eval.
// In terms of result quality it appears to make no difference, so we'll start with the easier approach given 5D tensors are not supported in ggml yet. // // In terms of result quality it appears to make no difference, so we'll start with the easier approach given 5D tensors are not supported in ggml yet.
// Without unpad we have to split the sub-image embeddings into patches of 24 features each and permute them. // // Without unpad we have to split the sub-image embeddings into patches of 24 features each and permute them.
// Once all images are processed to prepended the base_image_features without any changes. // // Once all images are processed to prepended the base_image_features without any changes.
// Pytorch reference simplified, modified for ggml compatibility - confirmed identical output in python (for a 2x2 grid image (676x676 scaling)) // // Pytorch reference simplified, modified for ggml compatibility - confirmed identical output in python (for a 2x2 grid image (676x676 scaling))
/* // /*
image_feature = image_feature.view(2, 2, 24, 24, 4096) // image_feature = image_feature.view(2, 2, 24, 24, 4096)
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() // image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.view(2, 24, 2, 24, 4096) // image_feature = image_feature.view(2, 24, 2, 24, 4096)
image_feature = image_feature.flatten(0, 3) // image_feature = image_feature.flatten(0, 3)
// Reshape to 4D tensor by merging the last two dimensions // // Reshape to 4D tensor by merging the last two dimensions
image_feature = image_feature.view(2, 2, 24, 24*4096) // image_feature = image_feature.view(2, 2, 24, 24*4096)
image_feature = image_feature.permute(0, 2, 1, 3).contiguous() // image_feature = image_feature.permute(0, 2, 1, 3).contiguous()
image_feature = image_feature.view(-1, 4096) // image_feature = image_feature.view(-1, 4096)
*/ // */
model.ctx = ggml_init(params); // model.ctx = ggml_init(params);
ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip); // ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip);
model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]); // model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]);
if (newline_tmp->backend != GGML_BACKEND_TYPE_CPU) { // if (newline_tmp->backend != GGML_BACKEND_TYPE_CPU) {
if (newline_tmp->buffer == NULL) { // if (newline_tmp->buffer == NULL) {
LOG_TEE("newline_tmp tensor buffer is NULL\n"); // LOG_TEE("newline_tmp tensor buffer is NULL\n");
} // }
ggml_backend_tensor_get(newline_tmp, model.newline->data, 0, ggml_nbytes(newline_tmp)); // ggml_backend_tensor_get(newline_tmp, model.newline->data, 0, ggml_nbytes(newline_tmp));
} else { // } else {
model.newline->data = newline_tmp->data; // model.newline->data = newline_tmp->data;
if (model.newline->data == NULL) { // if (model.newline->data == NULL) {
LOG_TEE("newline_tmp tensor data is NULL\n"); // LOG_TEE("newline_tmp tensor data is NULL\n");
} // }
} // }
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4 // struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4
// ggml_tensor_printf(image_features,"image_features",__LINE__,false,false); // // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false);
// fill it with the image embeddings, ignoring the base // // fill it with the image embeddings, ignoring the base
for (size_t i = 1; i < num_images; i++) { // for (size_t i = 1; i < num_images; i++) {
size_t offset = (i-1) * clip_embd_nbytes(ctx_clip); // size_t offset = (i-1) * clip_embd_nbytes(ctx_clip);
memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip)); // memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip));
} // }
struct ggml_cgraph * gf = ggml_new_graph(model.ctx); // struct ggml_cgraph * gf = ggml_new_graph(model.ctx);
size_t size_ele = ggml_type_size(GGML_TYPE_F32); // size_t size_ele = ggml_type_size(GGML_TYPE_F32);
struct ggml_tensor *image_features_patchview = ggml_view_4d(model.ctx, image_features, // struct ggml_tensor *image_features_patchview = ggml_view_4d(model.ctx, image_features,
num_patches_per_side * clip_n_mmproj_embd(ctx_clip), // num_patches_per_side * clip_n_mmproj_embd(ctx_clip),
num_patches_per_side, // num_patches_per_side,
num_patches_width, // num_patches_width,
num_patches_height, // num_patches_height,
size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip), // size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip),
size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side, // size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side,
size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side * num_patches_width, 0); // size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side * num_patches_width, 0);
// ggml_tensor_printf(image_features_patchview,"image_features_patchview",__LINE__,false,false); // // ggml_tensor_printf(image_features_patchview,"image_features_patchview",__LINE__,false,false);
struct ggml_tensor *permuted_cont = ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3)); // struct ggml_tensor *permuted_cont = ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3));
/** // /**
At the end of each row we have to add the row_end embeddings, which are the same as the newline embeddings // At the end of each row we have to add the row_end embeddings, which are the same as the newline embeddings
image_feature = torch.cat(( // image_feature = torch.cat((
image_feature, // image_feature,
self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) // self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
), dim=-1) // ), dim=-1)
* // *
*/ // */
// ggml_tensor_printf(permuted_cont,"permuted_cont",__LINE__,false,false); // // ggml_tensor_printf(permuted_cont,"permuted_cont",__LINE__,false,false);
struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0); // struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0);
// ggml_tensor_printf(flatten,"flatten",__LINE__,false,false); // // ggml_tensor_printf(flatten,"flatten",__LINE__,false,false);
ggml_build_forward_expand(gf, flatten); // ggml_build_forward_expand(gf, flatten);
ggml_graph_compute_with_ctx(model.ctx, gf, 1); // ggml_graph_compute_with_ctx(model.ctx, gf, 1);
struct ggml_tensor* result = gf->nodes[gf->n_nodes - 1]; // struct ggml_tensor* result = gf->nodes[gf->n_nodes - 1];
memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context // memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context
// append without newline tokens (default behavior in llava_arch when not using unpad ): // // append without newline tokens (default behavior in llava_arch when not using unpad ):
memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches // memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_patches(ctx_clip)); // *n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_patches(ctx_clip));
// Debug: Test single segments // // Debug: Test single segments
// Current findings: sending base image, sending a segment embedding all works similar to python // // Current findings: sending base image, sending a segment embedding all works similar to python
// However, permuted embeddings do not work yet (stride issue?) // // However, permuted embeddings do not work yet (stride issue?)
// memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as context // // memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as context
// memcpy(image_embd_out, (float*)prepared_cont->data, clip_embd_nbytes(ctx_clip)); // main image as context // // memcpy(image_embd_out, (float*)prepared_cont->data, clip_embd_nbytes(ctx_clip)); // main image as context
// *n_img_pos_out=576; // // *n_img_pos_out=576;
ggml_free(model.ctx); // ggml_free(model.ctx);
return true; // return true;
} // }
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
@ -254,52 +254,53 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
return false; return false;
} }
} else { }
// spatial_unpad llava-1.6 type embedding // else {
// TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working // // spatial_unpad llava-1.6 type embedding
std::vector<float *> image_embd_v; // // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working
image_embd_v.resize(img_res_v.size); // std::vector<float *> image_embd_v;
for (size_t i = 0; i < img_res_v.size; i++) { // image_embd_v.resize(img_res_v.size);
image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 // for (size_t i = 0; i < img_res_v.size; i++) {
const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i], load_image_size); // image data is in 3x336x336 format and will be converted to 336x336x3 inside // image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184
if (!encoded) { // const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i], load_image_size); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); // if (!encoded) {
return false; // LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
} // return false;
} // }
const int64_t t_img_enc_batch_us = ggml_time_us(); // }
LOG_TEE("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); // const int64_t t_img_enc_batch_us = ggml_time_us();
// LOG_TEE("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
const int32_t * image_grid = clip_image_grid(ctx_clip); // const int32_t * image_grid = clip_image_grid(ctx_clip);
std::vector<std::pair<int, int>> grid_pinpoints; // std::vector<std::pair<int, int>> grid_pinpoints;
for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) { // for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) {
grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); // grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
} // }
// free all img_res_v - not needed anymore // // free all img_res_v - not needed anymore
delete[] img_res_v.data; // delete[] img_res_v.data;
img_res_v.size = 0; // img_res_v.size = 0;
img_res_v.data = nullptr; // img_res_v.data = nullptr;
const int32_t image_size = clip_image_size(ctx_clip); // const int32_t image_size = clip_image_size(ctx_clip);
struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size); // struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);
int n_img_pos_out; // int n_img_pos_out;
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out); // clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
*n_img_pos = n_img_pos_out; // *n_img_pos = n_img_pos_out;
for (size_t i = 0; i < image_embd_v.size(); i++) { // for (size_t i = 0; i < image_embd_v.size(); i++) {
free(image_embd_v[i]); // free(image_embd_v[i]);
} // }
image_embd_v.clear(); // image_embd_v.clear();
// debug image/segment/normalization content: // // debug image/segment/normalization content:
// clip_image_u8 * tmp = clip_image_u8_init(); // // clip_image_u8 * tmp = clip_image_u8_init();
// clip_image_convert_f32_to_u8(*image_feature, *tmp); // // clip_image_convert_f32_to_u8(*image_feature, *tmp);
// clip_image_save_to_bmp(*tmp, "image_feature.bmp"); // // clip_image_save_to_bmp(*tmp, "image_feature.bmp");
} // }
LOG_TEE("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); LOG_TEE("%s: image embedding created: %d tokens\n", __func__, *n_img_pos);

View File

@ -60,7 +60,7 @@ struct clip_ctx * clip_init_context(gpt_params * params) {
if (prompt.empty()) { if (prompt.empty()) {
prompt = "describe the image in detail."; prompt = "describe the image in detail.";
} }
std::pair<int, int> load_image_size = std::make_pair(70, 70); std::pair<int, int> load_image_size = std::make_pair(448, 448);
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1, load_image_size); auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1, load_image_size);
return ctx_clip; return ctx_clip;
} }
@ -112,27 +112,21 @@ void process_image(struct minicpmv_context * ctx_llava, std::vector<std::vector<
LOG_TEE("%s: image token past: %d\n", __func__, n_past); LOG_TEE("%s: image token past: %d\n", __func__, n_past);
eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, true); eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, true);
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed_slices[0][0], params->n_batch, &n_past); llava_eval_image_embed(ctx_llava->ctx_llama, image_embed_slices[0][0], params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
if (image_embed_slices.size() > 1) { if (image_embed_slices.size() > 1) {
eval_string(ctx_llava->ctx_llama, std::string("</image><slice>").c_str(), params->n_batch, &n_past, false); eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false);
for (size_t i = 1; i < image_embed_slices.size(); ++i) { for (size_t i = 1; i < image_embed_slices.size(); ++i) {
eval_string(ctx_llava->ctx_llama, std::string("<image>").c_str(), params->n_batch, &n_past, false);
for (size_t j = 0; j < image_embed_slices[i].size(); ++j) { for (size_t j = 0; j < image_embed_slices[i].size(); ++j) {
eval_string(ctx_llava->ctx_llama, std::string("<image>").c_str(), params->n_batch, &n_past, false);
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed_slices[i][j], params->n_batch, &n_past); llava_eval_image_embed(ctx_llava->ctx_llama, image_embed_slices[i][j], params->n_batch, &n_past);
if (j != image_embed_slices[i].size() - 1) { eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
eval_string(ctx_llava->ctx_llama, std::string("</image><image>").c_str(), params->n_batch, &n_past, false); if (j == image_embed_slices[i].size() - 1) {
} else { eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false);
if (i != image_embed_slices.size() - 1) {
eval_string(ctx_llava->ctx_llama, std::string("</image>\n").c_str(), params->n_batch, &n_past, false);
} else {
eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
}
} }
} }
} }
eval_string(ctx_llava->ctx_llama, std::string("</slice>\n").c_str(), params->n_batch, &n_past, false); eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false);
} else {
eval_string(ctx_llava->ctx_llama, std::string("</image>\n").c_str(), params->n_batch, &n_past, false);
} }
LOG_TEE("%s: image token past: %d\n", __func__, n_past); LOG_TEE("%s: image token past: %d\n", __func__, n_past);
} }