mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
examples: support LLaVA v1.5 (multimodal model) (#3436)
* WIP: start implementing LLaVA * rm scratch buf for now, will revert after cleanup * LLaVA image encoder is working. will combine with llama * Add llava inference code, but it's buggy. debugging * LLaVA is working e2e, needs to optimize memory allocation + cleanup * Use ggml_allocr + rm unnecessary code * fix: crlf -> lf * fix: new line at EoF * fix: trailing whitespace * Add readme * Update readme * Some cleanup * Are you happy editorconfig? * rm unused batch image preprocessing * rm unused import * fix: rm designated initializers * introduce pad-to-square mode for non-square images * are you happy editorconfig? * gitignore /llava * Handle cases where image file does not exist * add llava target to Makefile * add support for 13b model variant * Maybe seed is unlucky? * Check if apples are compared to apples * are you happy editorconfig? * Use temperature = 0.1 by default * command line: use gpt_params_parse() * minor * handle default n_predict * fix typo * llava : code formatting, rename files, fix compile warnings * do not use Wno-cast-qual for MSVC --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
9e24cc6e2e
commit
370359e5ba
1
.gitignore
vendored
1
.gitignore
vendored
@ -44,6 +44,7 @@ models-mnt
|
|||||||
/infill
|
/infill
|
||||||
/libllama.so
|
/libllama.so
|
||||||
/llama-bench
|
/llama-bench
|
||||||
|
/llava
|
||||||
/main
|
/main
|
||||||
/metal
|
/metal
|
||||||
/perplexity
|
/perplexity
|
||||||
|
5
Makefile
5
Makefile
@ -1,7 +1,7 @@
|
|||||||
# Define the default target now so that it is always the first target
|
# Define the default target now so that it is always the first target
|
||||||
BUILD_TARGETS = \
|
BUILD_TARGETS = \
|
||||||
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
|
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
|
||||||
simple batched batched-bench save-load-state server embd-input-test gguf llama-bench baby-llama beam-search \
|
simple batched batched-bench save-load-state server embd-input-test gguf llama-bench llava baby-llama beam-search \
|
||||||
speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o
|
speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o
|
||||||
|
|
||||||
# Binaries only useful for tests
|
# Binaries only useful for tests
|
||||||
@ -627,6 +627,9 @@ convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggm
|
|||||||
llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
|
llava: examples/llava/llava.cpp examples/llava/llava-utils.h examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
||||||
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
|
||||||
|
|
||||||
baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
|
baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
|
@ -384,6 +384,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.lora_base = argv[i];
|
params.lora_base = argv[i];
|
||||||
|
} else if (arg == "--mmproj") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.mmproj = argv[i];
|
||||||
|
} else if (arg == "--image") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.image = argv[i];
|
||||||
} else if (arg == "-i" || arg == "--interactive") {
|
} else if (arg == "-i" || arg == "--interactive") {
|
||||||
params.interactive = true;
|
params.interactive = true;
|
||||||
} else if (arg == "--embedding") {
|
} else if (arg == "--embedding") {
|
||||||
@ -703,6 +715,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
|
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
|
||||||
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
|
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
|
||||||
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
||||||
|
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
|
||||||
|
printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
|
||||||
if (llama_mlock_supported()) {
|
if (llama_mlock_supported()) {
|
||||||
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
|
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
|
||||||
}
|
}
|
||||||
|
@ -104,6 +104,10 @@ struct gpt_params {
|
|||||||
bool numa = false; // attempt optimizations that help on some NUMA systems
|
bool numa = false; // attempt optimizations that help on some NUMA systems
|
||||||
bool verbose_prompt = false; // print prompt tokens before generation
|
bool verbose_prompt = false; // print prompt tokens before generation
|
||||||
bool infill = false; // use infill mode
|
bool infill = false; // use infill mode
|
||||||
|
|
||||||
|
// multimodal models (see examples/llava)
|
||||||
|
std::string mmproj = ""; // path to multimodal projector
|
||||||
|
std::string image = ""; // path to an image file
|
||||||
};
|
};
|
||||||
|
|
||||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||||
|
8396
common/stb_image.h
Normal file
8396
common/stb_image.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -29,6 +29,7 @@ else()
|
|||||||
add_subdirectory(speculative)
|
add_subdirectory(speculative)
|
||||||
add_subdirectory(parallel)
|
add_subdirectory(parallel)
|
||||||
add_subdirectory(embd-input)
|
add_subdirectory(embd-input)
|
||||||
|
add_subdirectory(llava)
|
||||||
add_subdirectory(llama-bench)
|
add_subdirectory(llama-bench)
|
||||||
add_subdirectory(beam-search)
|
add_subdirectory(beam-search)
|
||||||
if (LLAMA_METAL)
|
if (LLAMA_METAL)
|
||||||
|
20
examples/llava/CMakeLists.txt
Normal file
20
examples/llava/CMakeLists.txt
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
set(TARGET clip)
|
||||||
|
add_library(${TARGET} clip.cpp clip.h)
|
||||||
|
install(TARGETS ${TARGET} LIBRARY)
|
||||||
|
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||||
|
if (NOT MSVC)
|
||||||
|
target_compile_options(${TARGET} PRIVATE -Wno-cast-qual) # stb_image.h
|
||||||
|
endif()
|
||||||
|
if(TARGET BUILD_INFO)
|
||||||
|
add_dependencies(${TARGET} BUILD_INFO)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(TARGET llava)
|
||||||
|
add_executable(${TARGET} llava.cpp)
|
||||||
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
|
target_link_libraries(${TARGET} PRIVATE common llama clip ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||||
|
if(TARGET BUILD_INFO)
|
||||||
|
add_dependencies(${TARGET} BUILD_INFO)
|
||||||
|
endif()
|
57
examples/llava/README.md
Normal file
57
examples/llava/README.md
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# LLaVA
|
||||||
|
|
||||||
|
Currently this implementation supports [llava-v1.5](https://huggingface.co/liuhaotian/llava-v1.5-7b) variants.
|
||||||
|
|
||||||
|
The pre-converted [7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
|
||||||
|
and [13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
|
||||||
|
models are available.
|
||||||
|
|
||||||
|
After API is confirmed, more models will be supported / uploaded.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
Build with cmake or run `make llava` to build it.
|
||||||
|
|
||||||
|
After building, run: `./llava` to see the usage. For example:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./llava -m llava-v1.5-7b/ggml-model-q5_k.gguf --mmproj llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
|
||||||
|
|
||||||
|
## Model conversion
|
||||||
|
|
||||||
|
- Clone `llava-v15-7b`` and `clip-vit-large-patch14-336`` locally:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
|
||||||
|
|
||||||
|
git clone https://huggingface.co/openai/clip-vit-large-patch14-336
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Use `llava-surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python ./examples/llava/llava-surgery.py -m ../llava-v1.5-7b
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Use `convert-image-encoder-to-gguf.py` to convert the LLaVA image encoder to GGUF:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python ./examples/llava/convert-image-encoder-to-gguf -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python ./convert.py ../llava-v1.5-7b
|
||||||
|
```
|
||||||
|
|
||||||
|
Now both the LLaMA part and the image encoder is in the `llava-v1.5-7b` directory.
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
|
||||||
|
- [ ] Support server mode.
|
||||||
|
- [ ] Support non-CPU backend for the image encoding part.
|
||||||
|
- [ ] Support different sampling methods.
|
||||||
|
- [ ] Support more model variants.
|
1062
examples/llava/clip.cpp
Normal file
1062
examples/llava/clip.cpp
Normal file
File diff suppressed because it is too large
Load Diff
73
examples/llava/clip.h
Normal file
73
examples/llava/clip.h
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
#ifndef CLIP_H
|
||||||
|
#define CLIP_H
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
struct clip_ctx;
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
struct clip_vision_hparams {
|
||||||
|
int32_t image_size;
|
||||||
|
int32_t patch_size;
|
||||||
|
int32_t hidden_size;
|
||||||
|
int32_t n_intermediate;
|
||||||
|
int32_t projection_dim;
|
||||||
|
int32_t n_head;
|
||||||
|
int32_t n_layer;
|
||||||
|
float eps;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
|
||||||
|
|
||||||
|
void clip_free(struct clip_ctx * ctx);
|
||||||
|
|
||||||
|
size_t clip_embd_nbytes(struct clip_ctx * ctx);
|
||||||
|
int clip_n_patches(struct clip_ctx * ctx);
|
||||||
|
int clip_n_mmproj_embd(struct clip_ctx * ctx);
|
||||||
|
|
||||||
|
// RGB uint8 image
|
||||||
|
struct clip_image_u8 {
|
||||||
|
int nx;
|
||||||
|
int ny;
|
||||||
|
uint8_t * data;
|
||||||
|
size_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
// RGB float32 image (NHWC)
|
||||||
|
// Memory layout: RGBRGBRGB...
|
||||||
|
struct clip_image_f32 {
|
||||||
|
int nx;
|
||||||
|
int ny;
|
||||||
|
float * data;
|
||||||
|
size_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct clip_image_u8_batch {
|
||||||
|
struct clip_image_u8 * data;
|
||||||
|
size_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct clip_image_f32_batch {
|
||||||
|
struct clip_image_f32 * data;
|
||||||
|
size_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct clip_image_u8 * make_clip_image_u8();
|
||||||
|
struct clip_image_f32 * make_clip_image_f32();
|
||||||
|
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
||||||
|
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
|
||||||
|
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
|
||||||
|
|
||||||
|
bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs,
|
||||||
|
float * vec);
|
||||||
|
|
||||||
|
bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // CLIP_H
|
250
examples/llava/convert-image-encoder-to-gguf.py
Normal file
250
examples/llava/convert-image-encoder-to-gguf.py
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from gguf import *
|
||||||
|
from transformers import CLIPModel, CLIPProcessor
|
||||||
|
|
||||||
|
TEXT = "clip.text"
|
||||||
|
VISION = "clip.vision"
|
||||||
|
|
||||||
|
|
||||||
|
def k(raw_key: str, arch: str) -> str:
|
||||||
|
return raw_key.format(arch=arch)
|
||||||
|
|
||||||
|
|
||||||
|
def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
|
||||||
|
if name in (
|
||||||
|
"logit_scale",
|
||||||
|
"text_model.embeddings.position_ids",
|
||||||
|
"vision_model.embeddings.position_ids",
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if name.startswith("v") and not has_vision:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if name.startswith("t") and not has_text:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_name(name: str) -> str:
|
||||||
|
if "projection" in name:
|
||||||
|
return name
|
||||||
|
|
||||||
|
if "mm_projector" in name:
|
||||||
|
return name.replace("model.mm_projector", "mm")
|
||||||
|
|
||||||
|
return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_unicode():
|
||||||
|
"""
|
||||||
|
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||||
|
The reversible bpe codes work on unicode strings.
|
||||||
|
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||||
|
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||||
|
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||||
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||||
|
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||||
|
"""
|
||||||
|
bs = (
|
||||||
|
list(range(ord("!"), ord("~") + 1))
|
||||||
|
+ list(range(ord("¡"), ord("¬") + 1))
|
||||||
|
+ list(range(ord("®"), ord("ÿ") + 1))
|
||||||
|
)
|
||||||
|
cs = bs[:]
|
||||||
|
n = 0
|
||||||
|
for b in range(2**8):
|
||||||
|
if b not in bs:
|
||||||
|
bs.append(b)
|
||||||
|
cs.append(2**8 + n)
|
||||||
|
n += 1
|
||||||
|
cs = [chr(n) for n in cs]
|
||||||
|
return dict(zip(bs, cs))
|
||||||
|
|
||||||
|
|
||||||
|
ap = argparse.ArgumentParser(prog="convert_hf_to_gguf.py")
|
||||||
|
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
|
||||||
|
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
|
||||||
|
ap.add_argument("--text-only", action="store_true", required=False,
|
||||||
|
help="Save a text-only model. It can't be used to encode images")
|
||||||
|
ap.add_argument("--vision-only", action="store_true", required=False,
|
||||||
|
help="Save a vision-only model. It can't be used to encode texts")
|
||||||
|
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
|
||||||
|
ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
|
||||||
|
ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
|
||||||
|
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
|
||||||
|
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if args.text_only and args.vision_only:
|
||||||
|
print("--text-only and --image-only arguments cannot be specified at the same time.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if args.use_f32:
|
||||||
|
print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
|
||||||
|
|
||||||
|
# output in the same directory as the model if output_dir is None
|
||||||
|
dir_model = args.model_dir
|
||||||
|
|
||||||
|
|
||||||
|
with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
|
||||||
|
vocab = json.load(f)
|
||||||
|
tokens = [key for key in vocab]
|
||||||
|
|
||||||
|
with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
v_hparams = config["vision_config"]
|
||||||
|
t_hparams = config["text_config"]
|
||||||
|
|
||||||
|
# possible data types
|
||||||
|
# ftype == 0 -> float32
|
||||||
|
# ftype == 1 -> float16
|
||||||
|
#
|
||||||
|
# map from ftype to string
|
||||||
|
ftype_str = ["f32", "f16"]
|
||||||
|
|
||||||
|
ftype = 1
|
||||||
|
if args.use_f32:
|
||||||
|
ftype = 0
|
||||||
|
|
||||||
|
|
||||||
|
model = CLIPModel.from_pretrained(dir_model)
|
||||||
|
processor = CLIPProcessor.from_pretrained(dir_model)
|
||||||
|
|
||||||
|
fname_middle = None
|
||||||
|
has_text_encoder = True
|
||||||
|
has_vision_encoder = True
|
||||||
|
has_llava_projector = False
|
||||||
|
if args.text_only:
|
||||||
|
fname_middle = "text-"
|
||||||
|
has_vision_encoder = False
|
||||||
|
elif args.vision_only:
|
||||||
|
fname_middle = "vision-"
|
||||||
|
has_text_encoder = False
|
||||||
|
elif args.llava_projector is not None:
|
||||||
|
fname_middle = "mmproj-"
|
||||||
|
has_text_encoder = False
|
||||||
|
has_llava_projector = True
|
||||||
|
else:
|
||||||
|
fname_middle = ""
|
||||||
|
|
||||||
|
output_dir = args.output_dir if args.output_dir is not None else dir_model
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
output_prefix = os.path.basename(output_dir).replace("ggml_", "")
|
||||||
|
fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
|
||||||
|
fout = GGUFWriter(path=fname_out, arch="clip")
|
||||||
|
|
||||||
|
fout.add_bool("clip.has_text_encoder", has_text_encoder)
|
||||||
|
fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
|
||||||
|
fout.add_bool("clip.has_llava_projector", has_llava_projector)
|
||||||
|
fout.add_file_type(ftype)
|
||||||
|
model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model)
|
||||||
|
fout.add_name(model_name)
|
||||||
|
if args.text_only:
|
||||||
|
fout.add_description("text-only CLIP model")
|
||||||
|
elif args.vision_only and not has_llava_projector:
|
||||||
|
fout.add_description("vision-only CLIP model")
|
||||||
|
elif has_llava_projector:
|
||||||
|
fout.add_description("image encoder for LLaVA")
|
||||||
|
else:
|
||||||
|
fout.add_description("two-tower CLIP model")
|
||||||
|
|
||||||
|
if has_text_encoder:
|
||||||
|
# text_model hparams
|
||||||
|
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
|
||||||
|
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
|
||||||
|
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
|
||||||
|
fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
|
||||||
|
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
|
||||||
|
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
|
||||||
|
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
|
||||||
|
fout.add_token_list(tokens)
|
||||||
|
|
||||||
|
if has_vision_encoder:
|
||||||
|
# vision_model hparams
|
||||||
|
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
|
||||||
|
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
|
||||||
|
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
|
||||||
|
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
|
||||||
|
fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"]))
|
||||||
|
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
|
||||||
|
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
|
||||||
|
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
|
||||||
|
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
|
||||||
|
|
||||||
|
image_mean = processor.image_processor.image_mean if args.image_mean is None else args.image_mean
|
||||||
|
image_std = processor.image_processor.image_std if args.image_std is None else args.image_std
|
||||||
|
fout.add_array("clip.vision.image_mean", image_mean)
|
||||||
|
fout.add_array("clip.vision.image_std", image_std)
|
||||||
|
|
||||||
|
use_gelu = v_hparams["hidden_act"] == "gelu"
|
||||||
|
fout.add_bool("clip.use_gelu", use_gelu)
|
||||||
|
|
||||||
|
|
||||||
|
if has_llava_projector:
|
||||||
|
model.vision_model.encoder.layers.pop(-1)
|
||||||
|
projector = torch.load(args.llava_projector)
|
||||||
|
for name, data in projector.items():
|
||||||
|
name = get_tensor_name(name)
|
||||||
|
if data.ndim == 2:
|
||||||
|
data = data.squeeze().numpy().astype(np.float16)
|
||||||
|
else:
|
||||||
|
data = data.squeeze().numpy().astype(np.float32)
|
||||||
|
|
||||||
|
fout.add_tensor(name, data)
|
||||||
|
|
||||||
|
print("Projector tensors added\n")
|
||||||
|
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
for name, data in state_dict.items():
|
||||||
|
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
|
||||||
|
# we don't need this
|
||||||
|
print(f"skipping parameter: {name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = get_tensor_name(name)
|
||||||
|
data = data.squeeze().numpy()
|
||||||
|
|
||||||
|
n_dims = len(data.shape)
|
||||||
|
|
||||||
|
# ftype == 0 -> float32, ftype == 1 -> float16
|
||||||
|
ftype_cur = 0
|
||||||
|
if n_dims == 4:
|
||||||
|
print(f"tensor {name} is always saved in f16")
|
||||||
|
data = data.astype(np.float16)
|
||||||
|
ftype_cur = 1
|
||||||
|
elif ftype == 1:
|
||||||
|
if name[-7:] == ".weight" and n_dims == 2:
|
||||||
|
print(" Converting to float16")
|
||||||
|
data = data.astype(np.float16)
|
||||||
|
ftype_cur = 1
|
||||||
|
else:
|
||||||
|
print(" Converting to float32")
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
ftype_cur = 0
|
||||||
|
else:
|
||||||
|
if data.dtype != np.float32:
|
||||||
|
print(" Converting to float32")
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
ftype_cur = 0
|
||||||
|
|
||||||
|
print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
|
||||||
|
fout.add_tensor(name, data)
|
||||||
|
|
||||||
|
|
||||||
|
fout.write_header_to_file()
|
||||||
|
fout.write_kv_data_to_file()
|
||||||
|
fout.write_tensors_to_file()
|
||||||
|
fout.close()
|
||||||
|
|
||||||
|
print("Done. Output file: " + fname_out)
|
30
examples/llava/llava-surgery.py
Normal file
30
examples/llava/llava-surgery.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("-m", "--model", help="Path to LLaVA v1.5 model")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
# find the model part that includes the the multimodal projector weights
|
||||||
|
path = sorted(glob.glob(f"{args.model}/pytorch_model*.bin"))[-1]
|
||||||
|
checkpoint = torch.load(path)
|
||||||
|
|
||||||
|
# get a list of mm tensor names
|
||||||
|
mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_projector")]
|
||||||
|
|
||||||
|
# store these tensors in a new dictionary and torch.save them
|
||||||
|
projector = {name: checkpoint[name] for name in mm_tensors}
|
||||||
|
torch.save(projector, f"{args.model}/llava.projector")
|
||||||
|
|
||||||
|
# remove these tensors from the checkpoint and save it again
|
||||||
|
for name in mm_tensors:
|
||||||
|
del checkpoint[name]
|
||||||
|
|
||||||
|
torch.save(checkpoint, path)
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
|
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
|
||||||
|
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")
|
145
examples/llava/llava-utils.h
Normal file
145
examples/llava/llava-utils.h
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
// this one and clip lib will be eventually merged to a single lib, let's keep it this way for now
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
|
||||||
|
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||||
|
|
||||||
|
for (int i = 0; i < N; i += n_batch) {
|
||||||
|
int n_eval = N - i;
|
||||||
|
if (n_eval > n_batch) {
|
||||||
|
n_eval = n_batch;
|
||||||
|
}
|
||||||
|
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||||
|
if (llama_decode(ctx_llama, batch)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*n_past += n_eval;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
|
||||||
|
int N = (int) tokens.size();
|
||||||
|
for (int i = 0; i < N; i += n_batch) {
|
||||||
|
int n_eval = (int) tokens.size() - i;
|
||||||
|
if (n_eval > n_batch) {
|
||||||
|
n_eval = n_batch;
|
||||||
|
}
|
||||||
|
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
|
||||||
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*n_past += n_eval;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
|
||||||
|
std::vector<llama_token> tokens;
|
||||||
|
tokens.push_back(id);
|
||||||
|
return eval_tokens(ctx_llama, tokens, 1, n_past);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past){
|
||||||
|
std::string str2 = str;
|
||||||
|
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, true);
|
||||||
|
eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: use common/sampling.h
|
||||||
|
inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
|
||||||
|
// out of user input, sample next token
|
||||||
|
const float temp = params.sampling_params.temp;
|
||||||
|
const int32_t top_k = params.sampling_params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.sampling_params.top_k;
|
||||||
|
const float top_p = params.sampling_params.top_p;
|
||||||
|
const float tfs_z = params.sampling_params.tfs_z;
|
||||||
|
const float typical_p = params.sampling_params.typical_p;
|
||||||
|
// const int32_t repeat_last_n = params.sampling_params.repeat_last_n < 0 ? n_ctx : params.sampling_params.repeat_last_n;
|
||||||
|
// const float repeat_penalty = params.sampling_params.repeat_penalty;
|
||||||
|
// const float alpha_presence = params.sampling_params.presence_penalty;
|
||||||
|
// const float alpha_frequency = params.sampling_params.frequency_penalty;
|
||||||
|
const int mirostat = params.sampling_params.mirostat;
|
||||||
|
const float mirostat_tau = params.sampling_params.mirostat_tau;
|
||||||
|
const float mirostat_eta = params.sampling_params.mirostat_eta;
|
||||||
|
// const bool penalize_nl = params.sampling_params.penalize_nl;
|
||||||
|
|
||||||
|
llama_token id = 0;
|
||||||
|
{
|
||||||
|
auto logits = llama_get_logits(ctx_llama);
|
||||||
|
auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
|
||||||
|
|
||||||
|
// Apply params.logit_bias map
|
||||||
|
for (auto it = params.sampling_params.logit_bias.begin(); it != params.sampling_params.logit_bias.end(); it++) {
|
||||||
|
logits[it->first] += it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token_data> candidates;
|
||||||
|
candidates.reserve(n_vocab);
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
|
// TODO: Apply penalties
|
||||||
|
// float nl_logit = logits[llama_token_nl(ctx)];
|
||||||
|
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
|
||||||
|
// llama_sample_repetition_penalty(ctx, &candidates_p,
|
||||||
|
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||||
|
// last_n_repeat, repeat_penalty);
|
||||||
|
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
|
||||||
|
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||||
|
// last_n_repeat, alpha_frequency, alpha_presence);
|
||||||
|
// if (!penalize_nl) {
|
||||||
|
// logits[llama_token_nl(ctx)] = nl_logit;
|
||||||
|
// }
|
||||||
|
|
||||||
|
if (temp <= 0) {
|
||||||
|
// Greedy sampling
|
||||||
|
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
|
||||||
|
} else {
|
||||||
|
if (mirostat == 1) {
|
||||||
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||||
|
const int mirostat_m = 100;
|
||||||
|
llama_sample_temp(ctx_llama, &candidates_p, temp);
|
||||||
|
id = llama_sample_token_mirostat(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||||
|
} else if (mirostat == 2) {
|
||||||
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||||
|
llama_sample_temp(ctx_llama, &candidates_p, temp);
|
||||||
|
id = llama_sample_token_mirostat_v2(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||||
|
} else {
|
||||||
|
// Temperature sampling
|
||||||
|
llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
|
||||||
|
llama_sample_tail_free(ctx_llama, &candidates_p, tfs_z, 1);
|
||||||
|
llama_sample_typical(ctx_llama, &candidates_p, typical_p, 1);
|
||||||
|
llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
|
||||||
|
llama_sample_temp(ctx_llama, &candidates_p, temp);
|
||||||
|
id = llama_sample_token(ctx_llama, &candidates_p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
|
||||||
|
int id = sample_id(ctx_llama, params);
|
||||||
|
static std::string ret;
|
||||||
|
if (id == llama_token_eos(ctx_llama)) {
|
||||||
|
ret = "</s>";
|
||||||
|
} else {
|
||||||
|
ret = llama_token_to_piece(ctx_llama, id);
|
||||||
|
}
|
||||||
|
eval_id(ctx_llama, id, n_past);
|
||||||
|
return ret.c_str();
|
||||||
|
}
|
156
examples/llava/llava.cpp
Normal file
156
examples/llava/llava.cpp
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
#include "clip.h"
|
||||||
|
#include "llava-utils.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
static void show_additional_info(int /*argc*/, char ** argv) {
|
||||||
|
printf("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
|
||||||
|
printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char ** argv) {
|
||||||
|
ggml_time_init();
|
||||||
|
|
||||||
|
gpt_params params;
|
||||||
|
|
||||||
|
if (!gpt_params_parse(argc, argv, params)) {
|
||||||
|
show_additional_info(argc, argv);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.mmproj.empty() || params.image.empty()) {
|
||||||
|
gpt_print_usage(argc, argv, params);
|
||||||
|
show_additional_info(argc, argv);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * clip_path = params.mmproj.c_str();
|
||||||
|
const char * img_path = params.image.c_str();
|
||||||
|
|
||||||
|
if (params.prompt.empty()) {
|
||||||
|
params.prompt = "describe the image in detail.";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
|
||||||
|
|
||||||
|
// load and preprocess the image
|
||||||
|
clip_image_u8 img;
|
||||||
|
clip_image_f32 img_res;
|
||||||
|
|
||||||
|
if (!clip_image_load_from_file(img_path, &img)) {
|
||||||
|
fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
|
||||||
|
|
||||||
|
clip_free(ctx_clip);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!clip_image_preprocess(ctx_clip, &img, &img_res, /*pad2square =*/ true)) {
|
||||||
|
fprintf(stderr, "%s: unable to preprocess %s\n", __func__, img_path);
|
||||||
|
|
||||||
|
clip_free(ctx_clip);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int n_img_pos = clip_n_patches(ctx_clip);
|
||||||
|
int n_img_embd = clip_n_mmproj_embd(ctx_clip);
|
||||||
|
|
||||||
|
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
||||||
|
|
||||||
|
if (!image_embd) {
|
||||||
|
fprintf(stderr, "Unable to allocate memory for image embeddings\n");
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t t_img_enc_start_us = ggml_time_us();
|
||||||
|
if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
|
||||||
|
fprintf(stderr, "Unable to encode image\n");
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
const int64_t t_img_enc_end_us = ggml_time_us();
|
||||||
|
|
||||||
|
// we get the embeddings, free up the memory required for CLIP
|
||||||
|
clip_free(ctx_clip);
|
||||||
|
|
||||||
|
llama_backend_init(params.numa);
|
||||||
|
|
||||||
|
llama_model_params model_params = llama_model_default_params();
|
||||||
|
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
|
||||||
|
if (model == NULL) {
|
||||||
|
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
|
|
||||||
|
ctx_params.n_ctx = params.n_ctx < 2048 ? 2048 : params.n_ctx; // we need a longer context size to process image embeddings
|
||||||
|
ctx_params.n_threads = params.n_threads;
|
||||||
|
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
||||||
|
|
||||||
|
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
|
||||||
|
|
||||||
|
if (ctx_llama == NULL) {
|
||||||
|
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
||||||
|
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||||
|
if (n_img_embd != n_llama_embd) {
|
||||||
|
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd);
|
||||||
|
|
||||||
|
llama_free(ctx_llama);
|
||||||
|
llama_free_model(model);
|
||||||
|
llama_backend_free();
|
||||||
|
free(image_embd);
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// process the prompt
|
||||||
|
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
|
||||||
|
|
||||||
|
int n_past = 0;
|
||||||
|
|
||||||
|
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
||||||
|
|
||||||
|
// GG: are we sure that the should be a trailing whitespace at the end of this string?
|
||||||
|
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
|
||||||
|
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
|
||||||
|
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
|
||||||
|
eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
|
||||||
|
|
||||||
|
// generate the response
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
|
for (int i = 0; i < max_tgt_len; i++) {
|
||||||
|
const char * tmp = sample(ctx_llama, params, &n_past);
|
||||||
|
if (strcmp(tmp, "</s>") == 0) break;
|
||||||
|
|
||||||
|
printf("%s", tmp);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
|
{
|
||||||
|
const float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
|
||||||
|
|
||||||
|
printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_print_timings(ctx_llama);
|
||||||
|
|
||||||
|
llama_free(ctx_llama);
|
||||||
|
llama_free_model(model);
|
||||||
|
llama_backend_free();
|
||||||
|
free(image_embd);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
2
ggml.c
2
ggml.c
@ -14428,7 +14428,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
|
|||||||
int64_t t0 = ggml_perf_time_us();
|
int64_t t0 = ggml_perf_time_us();
|
||||||
UNUSED(t0);
|
UNUSED(t0);
|
||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS
|
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||||
|
|
||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
Loading…
Reference in New Issue
Block a user