diff --git a/common/common.cpp b/common/common.cpp index aa494291d..fe84039f7 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -893,7 +893,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - params.image = argv[i]; + params.image.emplace_back(argv[i]); return true; } if (arg == "-i" || arg == "--interactive") { @@ -1495,7 +1495,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); - printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n"); + printf(" --image IMAGE_FILE path to an image file. use with multimodal models. Specify multiple times for batching\n"); if (llama_supports_mlock()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } diff --git a/common/common.h b/common/common.h index eea63a114..3233d90e6 100644 --- a/common/common.h +++ b/common/common.h @@ -167,8 +167,8 @@ struct gpt_params { std::string cache_type_v = "f16"; // KV cache data type for the V // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector - std::string image = ""; // path to an image file + std::string mmproj = ""; // path to multimodal projector + std::vector image; // path to image file(s) }; bool parse_kv_override(const char * data, std::vector & overrides); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index a44c6cd76..157a680b5 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -113,11 +113,11 @@ struct llava_context { }; static void show_additional_info(int /*argc*/, char ** argv) { - LOG_TEE("\n example usage: %s -m --mmproj --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); + LOG_TEE("\n example usage: %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); LOG_TEE(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); } -static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) { +static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params, const std::string & fname) { // load and preprocess the image llava_image_embed * embed = NULL; @@ -133,9 +133,9 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para } params->prompt = remove_image_from_prompt(prompt); } else { - embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str()); + embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, fname.c_str()); if (!embed) { - LOG_TEE("%s: is %s really an image file?\n", __func__, params->image.c_str()); + fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str()); return NULL; } } @@ -207,17 +207,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ printf("\n"); } - -static struct llava_context * llava_init(gpt_params * params) { - const char * clip_path = params->mmproj.c_str(); - - auto prompt = params->prompt; - if (prompt.empty()) { - prompt = "describe the image in detail."; - } - - auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); - +static struct llama_model * llava_init(gpt_params * params) { llama_backend_init(); llama_numa_init(params->numa); @@ -228,6 +218,19 @@ static struct llava_context * llava_init(gpt_params * params) { LOG_TEE("%s: error: unable to load model\n" , __func__); return NULL; } + return model; +} + +static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) { + const char * clip_path = params->mmproj.c_str(); + + auto prompt = params->prompt; + if (prompt.empty()) { + prompt = "describe the image in detail."; + } + + auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); + llama_context_params ctx_params = llama_context_params_from_gpt_params(*params); ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings @@ -286,24 +289,30 @@ int main(int argc, char ** argv) { show_additional_info(argc, argv); return 1; } - - auto ctx_llava = llava_init(¶ms); - if (ctx_llava == NULL) { - LOG_TEE("%s: error: failed to init llava\n", __func__); + auto model = llava_init(¶ms); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to init llava model\n", __func__); return 1; } - auto image_embed = load_image(ctx_llava, ¶ms); - if (!image_embed) { - return 1; + for (auto & image : params.image) { + auto ctx_llava = llava_init_context(¶ms, model); + + auto image_embed = load_image(ctx_llava, ¶ms, image); + if (!image_embed) { + std::cerr << "error: failed to load image " << image << ". Terminating\n\n"; + return 1; + } + + // process the prompt + process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); + + llama_print_timings(ctx_llava->ctx_llama); + llava_image_embed_free(image_embed); + ctx_llava->model = NULL; + llava_free(ctx_llava); } + llama_free_model(model); - // process the prompt - process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - - llama_print_timings(ctx_llava->ctx_llama); - - llava_image_embed_free(image_embed); - llava_free(ctx_llava); return 0; }