2023-10-12 17:23:18 +02:00
# include "clip.h"
# include "common.h"
# include "llama.h"
2023-11-06 22:36:23 +01:00
# include "llava.h"
2023-10-12 17:23:18 +02:00
# include <cstdio>
# include <cstdlib>
# include <vector>
2023-11-06 22:36:23 +01:00
# include "base64.hpp"
2023-10-12 17:23:18 +02:00
2023-11-06 22:36:23 +01:00
// Preprocess `img` for CLIP and encode it into the caller-provided `image_embd` buffer.
//
// On success, *n_img_pos holds the value returned by clip_n_patches(ctx_clip), the
// elapsed encoding time is printed to stdout, and true is returned.
// On failure (preprocessing or encoding), a message is written to stderr and false is
// returned. The temporary preprocessed image is released on every path.
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
    clip_image_f32 * preprocessed = make_clip_image_f32();

    if (!clip_image_preprocess(ctx_clip, img, preprocessed, /*pad2square =*/ true)) {
        fprintf(stderr, " %s: unable to preprocess image \n ", __func__);
        clip_image_f32_free(preprocessed);
        return false;
    }

    *n_img_pos = clip_n_patches(ctx_clip);

    const int64_t t_start_us = ggml_time_us();
    const bool encode_ok = clip_image_encode(ctx_clip, n_threads, preprocessed, image_embd);

    // The preprocessed image is not needed any more, whether or not encoding worked.
    clip_image_f32_free(preprocessed);

    if (encode_ok) {
        const int64_t t_end_us = ggml_time_us();
        float elapsed_ms = (t_end_us - t_start_us) / 1000.0;

        printf(" \n %s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch) \n ", __func__, elapsed_ms, elapsed_ms / *n_img_pos);
        return true;
    }

    fprintf(stderr, " Unable to encode image \n ");
    return false;
}
2023-10-12 17:23:18 +02:00
2023-11-06 22:36:23 +01:00
// Check that the multimodal projector in `ctx_clip` and the model behind `ctx_llama`
// agree on the embedding dimension.
//
// Returns true when clip_n_mmproj_embd(ctx_clip) equals llama_n_embd() of the model,
// false otherwise. The mismatch diagnostic is written to stderr, like every other
// error message in this file, so it does not get mixed into generated output on stdout.
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
    // make sure that the correct mmproj was used, i.e., compare apples to apples
    const int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
    // explicit int so the value matches the %d conversion below
    const int n_image_embd = clip_n_mmproj_embd(ctx_clip);

    if (n_image_embd != n_llama_embd) {
        fprintf(stderr, " %s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file. \n ", __func__, n_image_embd, n_llama_embd);
        return false;
    }

    return true;
}
2023-10-12 17:23:18 +02:00
2023-11-06 22:36:23 +01:00
// Allocate an embedding buffer of clip_embd_nbytes(ctx_clip) bytes and fill it by
// encoding `img` with CLIP.
//
// On success returns true, stores the malloc'ed buffer in *image_embd_out and the
// position count reported by encode_image_with_clip in *n_img_pos_out; the caller
// then owns the buffer and must release it with free().
// On failure returns false, logs to stderr, frees anything allocated here and leaves
// both output parameters untouched.
static bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
    float * image_embd = (float *) malloc(clip_embd_nbytes(ctx_clip));
    if (!image_embd) {
        // nothing to release here: the allocation itself failed
        fprintf(stderr, " Unable to allocate memory for image embeddings \n ");
        return false;
    }

    int n_img_pos = 0;
    if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) {
        fprintf(stderr, " %s: cannot encode image, aborting \n ", __func__);
        free(image_embd);
        return false;
    }

    *image_embd_out = image_embd;
    *n_img_pos_out = n_img_pos;

    return true;
}
2023-10-14 12:52:44 +02:00
2023-11-06 22:36:23 +01:00
// Feed the image embedding to the model in chunks of at most `n_batch` positions.
//
// Each chunk is submitted through llama_decode(); *n_past is advanced by the number of
// positions in every chunk that decoded successfully. Returns false (after logging to
// stderr) as soon as llama_decode() reports a non-zero result, true once all
// image_embed->n_image_pos positions have been processed.
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
    const int n_embd = llama_n_embd(llama_get_model(ctx_llama));

    int pos = 0;
    while (pos < image_embed->n_image_pos) {
        // clamp the chunk to what is left of the image
        const int remaining = image_embed->n_image_pos - pos;
        const int n_eval = remaining > n_batch ? n_batch : remaining;

        // Positional initializer: the meaning of each slot is fixed by the llama_batch
        // declaration (not visible in this file). The third slot points at the first
        // float of the current chunk, n_embd floats per position.
        llama_batch batch = { int32_t(n_eval), nullptr, (image_embed->embed + pos*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
        if (llama_decode(ctx_llama, batch)) {
            fprintf(stderr, " %s : failed to eval \n ", __func__);
            return false;
        }

        *n_past += n_eval;
        pos += n_batch;
    }

    return true;
}
2023-10-12 17:23:18 +02:00
2023-11-06 22:36:23 +01:00
// Build a llava_image_embed from an encoded image held in memory.
//
// `image_bytes` / `image_bytes_length` are handed to clip_image_load_from_bytes(); the
// decoded image is then embedded with CLIP using `n_threads` threads.
//
// Returns a malloc'ed llava_image_embed whose `embed` buffer is also malloc'ed, or NULL
// on any failure (image cannot be loaded, embedding fails, or out of memory). Every
// intermediate resource is released on the failure paths. A matching release routine
// in this file is llava_image_embed_free().
LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length) {
    clip_image_u8 * img = make_clip_image_u8();
    if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) {
        clip_image_u8_free(img);
        fprintf(stderr, " %s: can't load image from bytes, is it a valid image? \n ", __func__);
        return NULL;
    }

    float * image_embed = NULL;
    int n_image_pos = 0;
    const bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos);

    // the decoded image is only needed while computing the embedding
    clip_image_u8_free(img);

    if (!image_embed_result) {
        fprintf(stderr, " %s: couldn't embed the image \n ", __func__);
        return NULL;
    }

    auto result = (llava_image_embed *) malloc(sizeof(*result));
    if (!result) {
        // do not leak the embedding buffer when the wrapper cannot be allocated
        fprintf(stderr, "%s: failed to allocate the image embed structure\n", __func__);
        free(image_embed);
        return NULL;
    }

    result->embed = image_embed;
    result->n_image_pos = n_image_pos;
    return result;
}
2023-10-12 17:23:18 +02:00
2023-11-06 22:36:23 +01:00
// Read the whole file at `path` into a freshly malloc'ed buffer.
//
// On success returns true, stores the buffer in *bytesOut and its length in *sizeOut;
// the caller owns the buffer and must free() it. An empty file yields a valid
// (non-NULL) buffer with *sizeOut == 0.
// On failure (cannot open, cannot determine size, out of memory, read error or short
// read) a message is written to stderr, the file is closed, the buffer is released,
// false is returned and the output parameters are left untouched.
static bool load_file_to_bytes(const char * path, unsigned char ** bytesOut, long * sizeOut) {
    FILE * file = fopen(path, "rb");
    if (file == NULL) {
        fprintf(stderr, " %s: can't read file %s \n ", __func__, path);
        return false;
    }

    // Determine the file size; ftell() reports failure as -1, which must never reach malloc().
    long fileSize = -1;
    if (fseek(file, 0, SEEK_END) == 0) {
        fileSize = ftell(file);
    }
    if (fileSize < 0 || fseek(file, 0, SEEK_SET) != 0) {
        fprintf(stderr, "%s: can't determine the size of file %s\n", __func__, path);
        fclose(file);
        return false;
    }

    // malloc(0) is allowed to return NULL, so always request at least one byte.
    unsigned char * buffer = (unsigned char *) malloc(fileSize > 0 ? (size_t) fileSize : 1);
    if (buffer == NULL) {
        fprintf(stderr, " %s: failed to alloc %ld bytes for file %s \n ", __func__, fileSize, path);
        perror(" Memory allocation error ");
        fclose(file);
        return false;
    }

    errno = 0;
    const size_t ret = fread(buffer, 1, (size_t) fileSize, file); // Read the file into the buffer
    if (ferror(file)) {
        fprintf(stderr, "%s: read error: %s\n", __func__, strerror(errno));
        free(buffer);
        fclose(file);
        return false;
    }
    if (ret != (size_t) fileSize) {
        fprintf(stderr, "%s: unexpectedly reached end of file %s\n", __func__, path);
        free(buffer);
        fclose(file);
        return false;
    }
    fclose(file); // Close the file

    *bytesOut = buffer;
    *sizeOut = fileSize;
    return true;
}
2023-10-12 17:23:18 +02:00
2023-11-06 22:36:23 +01:00
// Build a llava_image_embed from the image file at `image_path`.
//
// The file is read fully into memory with load_file_to_bytes() and then handed to
// llava_image_embed_make_with_bytes(). The temporary file buffer is always released
// before returning.
//
// Returns whatever llava_image_embed_make_with_bytes() returns, or NULL when the file
// cannot be loaded or its size does not fit the `int` length that function accepts.
LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) {
    unsigned char * image_bytes = NULL;
    long image_bytes_length = 0;
    auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length);
    if (!loaded) {
        fprintf(stderr, " %s: failed to load %s \n ", __func__, image_path);
        return NULL;
    }

    // The byte-based entry point takes an int length; refuse to pass a silently
    // truncated (possibly negative) size for very large files.
    const int n_bytes = (int) image_bytes_length;
    if (n_bytes < 0 || (long) n_bytes != image_bytes_length) {
        fprintf(stderr, "%s: file %s is too large (%ld bytes)\n", __func__, image_path, image_bytes_length);
        free(image_bytes);
        return NULL;
    }

    auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, n_bytes);
    free(image_bytes);

    return embed;
}
2023-10-12 17:23:18 +02:00
2023-11-06 22:36:23 +01:00
// Release a llava_image_embed: frees its `embed` buffer and then the structure itself.
// Passing NULL is a no-op, mirroring free(), so callers can hand over the result of a
// failed llava_image_embed_make_with_* call without checking it first.
LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed) {
    if (embed == NULL) {
        return;
    }
    free(embed->embed);
    free(embed);
}