2023-11-06 22:36:23 +01:00
#include "ggml.h"
#include "common.h"
#include "clip.h"
#include "llava.h"
#include "llama.h"

#include "base64.hpp"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <string>
#include <vector>
static bool eval_tokens ( struct llama_context * ctx_llama , std : : vector < llama_token > tokens , int n_batch , int * n_past ) {
int N = ( int ) tokens . size ( ) ;
for ( int i = 0 ; i < N ; i + = n_batch ) {
int n_eval = ( int ) tokens . size ( ) - i ;
if ( n_eval > n_batch ) {
n_eval = n_batch ;
}
if ( llama_decode ( ctx_llama , llama_batch_get_one ( & tokens [ i ] , n_eval , * n_past , 0 ) ) ) {
fprintf ( stderr , " %s : failed to eval. token %d/%d (batch size %d, n_past %d) \n " , __func__ , i , N , n_batch , * n_past ) ;
return false ;
}
* n_past + = n_eval ;
}
return true ;
}
static bool eval_id ( struct llama_context * ctx_llama , int id , int * n_past ) {
std : : vector < llama_token > tokens ;
tokens . push_back ( id ) ;
return eval_tokens ( ctx_llama , tokens , 1 , n_past ) ;
}
static bool eval_string ( struct llama_context * ctx_llama , const char * str , int n_batch , int * n_past , bool add_bos ) {
std : : string str2 = str ;
std : : vector < llama_token > embd_inp = : : llama_tokenize ( ctx_llama , str2 , add_bos ) ;
eval_tokens ( ctx_llama , embd_inp , n_batch , n_past ) ;
return true ;
}
2023-12-29 15:38:38 +01:00
// Draws the next token with `ctx_sampling` and records it as accepted.
// It then feeds the token back through the model (advancing *n_past) and
// returns its text. The end-of-sequence token is reported as the " </s> " marker.
// NOTE: the returned pointer refers to function-static storage and is only
// valid until the next call.
static const char * sample(struct llama_sampling_context * ctx_sampling,
                           struct llama_context * ctx_llama,
                           int * n_past) {
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
    llama_sampling_accept(ctx_sampling, ctx_llama, id, true);

    static std::string piece;
    const bool is_eos = id == llama_token_eos(llama_get_model(ctx_llama));
    piece = is_eos ? std::string(" </s> ") : llama_token_to_piece(ctx_llama, id);

    eval_id(ctx_llama, id, n_past);
    return piece.c_str();
}
static const char * IMG_BASE64_TAG_BEGIN = " <img src= \" data:image/jpeg;base64, ";
static const char * IMG_BASE64_TAG_END = " \" > ";

// Locates the first base64 image tag in `prompt`.
//   begin_out: offset of IMG_BASE64_TAG_BEGIN, or npos if absent.
//   end_out:   offset of the first IMG_BASE64_TAG_END found at or after begin_out
//              (searched from offset 0 when begin_out is npos), or npos.
static void find_image_tag_in_prompt(const std::string & prompt, size_t & begin_out, size_t & end_out) {
    const size_t open_pos = prompt.find(IMG_BASE64_TAG_BEGIN);
    size_t search_from = 0;
    if (open_pos != std::string::npos) {
        search_from = open_pos;
    }
    begin_out = open_pos;
    end_out = prompt.find(IMG_BASE64_TAG_END, search_from);
}

// True when `prompt` contains the opening base64 image tag.
// The closing tag is not checked here.
static bool prompt_contains_image(const std::string & prompt) {
    return prompt.find(IMG_BASE64_TAG_BEGIN) != std::string::npos;
}
// Decodes the base64 image embedded in `prompt` and builds an image embedding
// from the decoded bytes. The image sits between IMG_BASE64_TAG_BEGIN and
// IMG_BASE64_TAG_END.
// Returns NULL (after logging to stderr) in any of these cases:
//   - the tag is missing;
//   - the payload is empty;
//   - the payload is not valid base64;
//   - the bytes cannot be turned into an embedding.
// The caller owns the result (main releases it with llava_image_embed_free).
static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string & prompt) {
    size_t img_base64_str_start, img_base64_str_end;
    find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
    if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
        fprintf(stderr, " %s: invalid base64 image tag. must be %s<base64 byte string>%s \n ", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
        return NULL;
    }

    const size_t base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
    // The closing tag may overlap the tail of the opening tag. Without this
    // check, the unsigned subtraction below would wrap around.
    if (img_base64_str_end <= base64_bytes_start) {
        fprintf(stderr, "%s: base64 image payload is empty\n", __func__);
        return NULL;
    }
    const size_t base64_bytes_count = img_base64_str_end - base64_bytes_start;
    std::string base64_str = prompt.substr(base64_bytes_start, base64_bytes_count);

    // Decoded data is never longer than its base64 text, so the input length
    // is a safe upper bound for the buffer. After decoding, shrink the buffer
    // to the number of bytes actually written, so the image loader gets no
    // trailing zero padding.
    std::vector<unsigned char> img_bytes(base64_str.size());
    try {
        auto decoded_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
        img_bytes.resize(static_cast<size_t>(decoded_end - img_bytes.begin()));
    } catch (const std::exception & e) {
        // NOTE(review): assumes base64.hpp reports malformed input with an
        // exception derived from std::exception — confirm against base64.hpp.
        fprintf(stderr, "%s: invalid base64 image payload: %s\n", __func__, e.what());
        return NULL;
    }

    auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
    if (!embed) {
        fprintf(stderr, " %s: could not load image from base64 string. \n ", __func__);
        return NULL;
    }

    return embed;
}
static std : : string remove_image_from_prompt ( const std : : string & prompt , const char * replacement = " " ) {
size_t begin , end ;
find_image_tag_in_prompt ( prompt , begin , end ) ;
if ( begin = = std : : string : : npos | | end = = std : : string : : npos ) {
return prompt ;
}
auto pre = prompt . substr ( 0 , begin ) ;
auto post = prompt . substr ( end + strlen ( IMG_BASE64_TAG_END ) ) ;
return pre + replacement + post ;
}
// Bundle of the handles that llava_init creates and llava_free releases:
// the CLIP context, the llama inference context and the llama model.
// All three start out null.
struct llava_context {
    struct clip_ctx      * ctx_clip  = nullptr;
    struct llama_context * ctx_llama = nullptr;
    struct llama_model   * model     = nullptr;
};
// Prints an example command line (with argv[0] as the program name) and a
// sampling-temperature hint to stderr. argc is unused.
static void show_additional_info(int /*argc*/, char ** argv) {
    const char * program = argv[0];
    fprintf(stderr, " \n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \" describe the image in detail. \" ] \n ", program);
    fputs(" note: a lower temperature value like 0.1 is recommended for better quality. \n ", stderr);
}
// Builds the image embedding for this run.
// If params->prompt embeds a base64 image tag, that image takes precedence over
// params->image (a notice is printed when both are given). On success the tag is
// stripped from params->prompt. Otherwise params->image is loaded as a file path.
// Returns NULL (after logging to stderr) on failure.
static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
    const std::string prompt = params->prompt;

    if (!prompt_contains_image(prompt)) {
        llava_image_embed * from_file = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
        if (from_file == NULL) {
            fprintf(stderr, " %s: is %s really an image file? \n ", __func__, params->image.c_str());
        }
        return from_file;
    }

    if (!params->image.empty()) {
        fprintf(stderr, " using base64 encoded image instead of command line image path \n ");
    }

    llava_image_embed * from_prompt = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt);
    if (from_prompt == NULL) {
        fprintf(stderr, " %s: can't load image from prompt \n ", __func__);
        return NULL;
    }
    params->prompt = remove_image_from_prompt(prompt);
    return from_prompt;
}
static void process_prompt ( struct llava_context * ctx_llava , struct llava_image_embed * image_embed , gpt_params * params , const std : : string & prompt ) {
int n_past = 0 ;
const int max_tgt_len = params - > n_predict < 0 ? 256 : params - > n_predict ;
2023-11-17 03:14:37 +01:00
const bool add_bos = llama_should_add_bos_token ( llama_get_model ( ctx_llava - > ctx_llama ) ) ;
2023-11-06 22:36:23 +01:00
// llava chat format is "<system_prompt>\nUSER:<image_embeddings>\n<textual_prompt>\nASSISTANT:"
2023-11-17 03:14:37 +01:00
eval_string ( ctx_llava - > ctx_llama , " A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. \n USER: " , params - > n_batch , & n_past , add_bos ) ;
2023-11-06 22:36:23 +01:00
llava_eval_image_embed ( ctx_llava - > ctx_llama , image_embed , params - > n_batch , & n_past ) ;
eval_string ( ctx_llava - > ctx_llama , ( prompt + " \n ASSISTANT: " ) . c_str ( ) , params - > n_batch , & n_past , false ) ;
// generate the response
2023-12-29 15:38:38 +01:00
fprintf ( stderr , " \n " ) ;
struct llama_sampling_context * ctx_sampling = llama_sampling_init ( params - > sparams ) ;
2023-11-06 22:36:23 +01:00
for ( int i = 0 ; i < max_tgt_len ; i + + ) {
2023-12-29 15:38:38 +01:00
const char * tmp = sample ( ctx_sampling , ctx_llava - > ctx_llama , & n_past ) ;
2023-11-06 22:36:23 +01:00
if ( strcmp ( tmp , " </s> " ) = = 0 ) break ;
printf ( " %s " , tmp ) ;
fflush ( stdout ) ;
}
2023-12-29 15:38:38 +01:00
llama_sampling_free ( ctx_sampling ) ;
2023-11-06 22:36:23 +01:00
printf ( " \n " ) ;
}
// Performs these steps in order:
//   - fills in the default prompt when params->prompt is empty;
//   - loads the CLIP model from params->mmproj;
//   - initializes the llama backend;
//   - loads the llama model from params->model;
//   - creates a llama context with at least 2048 tokens.
// Returns a malloc'd llava_context (release with llava_free), or NULL on failure.
// Everything acquired before a failure is released before returning.
static struct llava_context * llava_init(gpt_params * params) {
    const char * clip_path = params->mmproj.c_str();

    // Apply the default prompt. The previous code assigned it to a local copy
    // that was never used, so the default was silently discarded.
    if (params->prompt.empty()) {
        params->prompt = " describe the image in detail. ";
    }

    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
    if (ctx_clip == NULL) {
        fprintf(stderr, "%s: error: unable to load clip model from %s\n", __func__, clip_path);
        return NULL;
    }

    llama_backend_init(params->numa);

    llama_model_params model_params = llama_model_params_from_gpt_params(*params);
    llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
    if (model == NULL) {
        fprintf(stderr, " %s: error: unable to load model \n ", __func__);
        clip_free(ctx_clip);
        llama_backend_free();
        return NULL;
    }

    llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
    ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings

    llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
    if (ctx_llama == NULL) {
        fprintf(stderr, " %s: error: failed to create the llama_context \n ", __func__);
        llama_free_model(model);
        clip_free(ctx_clip);
        llama_backend_free();
        return NULL;
    }

    // Allocated with malloc: llava_free releases it with free().
    auto ctx_llava = (struct llava_context *) malloc(sizeof(llava_context));
    if (ctx_llava == NULL) {
        fprintf(stderr, "%s: error: out of memory allocating llava_context\n", __func__);
        llama_free(ctx_llama);
        llama_free_model(model);
        clip_free(ctx_clip);
        llama_backend_free();
        return NULL;
    }

    ctx_llava->ctx_llama = ctx_llama;
    ctx_llava->ctx_clip = ctx_clip;
    ctx_llava->model = model;
    return ctx_llava;
}
static void llava_free ( struct llava_context * ctx_llava ) {
if ( ctx_llava - > ctx_clip ) {
clip_free ( ctx_llava - > ctx_clip ) ;
ctx_llava - > ctx_clip = NULL ;
}
llama_free ( ctx_llava - > ctx_llama ) ;
llama_free_model ( ctx_llava - > model ) ;
llama_backend_free ( ) ;
}
int main ( int argc , char * * argv ) {
ggml_time_init ( ) ;
gpt_params params ;
if ( ! gpt_params_parse ( argc , argv , params ) ) {
show_additional_info ( argc , argv ) ;
return 1 ;
}
if ( params . mmproj . empty ( ) | | ( params . image . empty ( ) & & ! prompt_contains_image ( params . prompt ) ) ) {
gpt_print_usage ( argc , argv , params ) ;
show_additional_info ( argc , argv ) ;
return 1 ;
}
auto ctx_llava = llava_init ( & params ) ;
if ( ctx_llava = = NULL ) {
fprintf ( stderr , " %s: error: failed to init llava \n " , __func__ ) ;
return 1 ;
}
auto image_embed = load_image ( ctx_llava , & params ) ;
2024-01-09 18:59:14 +01:00
if ( ! image_embed ) {
return 1 ;
}
2023-11-06 22:36:23 +01:00
// process the prompt
process_prompt ( ctx_llava , image_embed , & params , params . prompt ) ;
llama_print_timings ( ctx_llava - > ctx_llama ) ;
llava_image_embed_free ( image_embed ) ;
llava_free ( ctx_llava ) ;
return 0 ;
}