Merge branch 'prepare-PR-of-minicpm-v2.5' into prepare-PR

This commit is contained in:
tc-mb 2024-05-29 02:49:59 +08:00 committed by GitHub
commit 8767ce29cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 3457 additions and 73 deletions

2
.gitignore vendored
View File

@ -60,6 +60,8 @@ models-mnt
/libllama.so /libllama.so
/llama-bench /llama-bench
/llava-cli /llava-cli
/minicpmv-cli
/openbmb
/lookahead /lookahead
/lookup /lookup
/lookup-create /lookup-create

View File

@ -1,7 +1,7 @@
# Define the default target now so that it is always the first target # Define the default target now so that it is always the first target
BUILD_TARGETS = \ BUILD_TARGETS = \
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf gguf-split eval-callback llama-bench libllava.a llava-cli baby-llama beam-search \ simple batched batched-bench save-load-state server gguf gguf-split eval-callback llama-bench libllava.a llava-cli minicpmv-cli baby-llama beam-search \
retrieval speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o retrieval speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
# Binaries only useful for tests # Binaries only useful for tests
@ -862,6 +862,13 @@ llava-cli: examples/llava/llava-cli.cpp examples/llava/clip.h examples/llava/cli
$(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp) $(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp)
$(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) -o $@ $(LDFLAGS)
minicpmv-cli: examples/minicpmv/minicpmv-cli.cpp examples/minicpmv/clip.h examples/minicpmv/clip.cpp examples/minicpmv/minicpmv.h examples/minicpmv/minicpmv.cpp examples/minicpmv/minicpmv_wrapper.h examples/minicpmv/minicpmv_wrapper.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) -c examples/minicpmv/clip.cpp -o $(call GET_OBJ_FILE, examples/minicpmv/clip.cpp) -Wno-cast-qual
$(CXX) $(CXXFLAGS) -c examples/minicpmv/minicpmv.cpp -o $(call GET_OBJ_FILE, examples/minicpmv/minicpmv.cpp)
$(CXX) $(CXXFLAGS) -c examples/minicpmv/minicpmv_wrapper.cpp -o $(call GET_OBJ_FILE, examples/minicpmv/minicpmv_wrapper.cpp)
$(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/minicpmv/clip.cpp examples/minicpmv/minicpmv.cpp examples/minicpmv/minicpmv_wrapper.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/minicpmv/clip.cpp) $(call GET_OBJ_FILE, examples/minicpmv/minicpmv.cpp) $(call GET_OBJ_FILE, examples/minicpmv/minicpmv_wrapper.cpp) -o $@ $(LDFLAGS)
baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS) baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View File

@ -675,44 +675,6 @@ class GPTNeoXModel(Model):
self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
tensors: list[tuple[str, Tensor]] = []
if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name):
# Map bloom-style qkv_linear to gpt-style qkv_linear
# bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
# gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
qkv_weights = data_torch.reshape((n_head, 3, n_embed // n_head, n_embed))
data_torch = torch.cat(
(
qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
),
dim=0,
)
logger.info("re-format attention.linear_qkv.weight")
elif re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.bias", name):
qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head))
data_torch = torch.cat(
(
qkv_bias[:, 0, :].reshape((n_embed,)),
qkv_bias[:, 1, :].reshape((n_embed,)),
qkv_bias[:, 2, :].reshape((n_embed,)),
),
dim=0,
)
logger.info("re-format attention.linear_qkv.bias")
tensors.append((self.map_tensor_name(name), data_torch))
return tensors
@Model.register("BloomForCausalLM") @Model.register("BloomForCausalLM")
class BloomModel(Model): class BloomModel(Model):

View File

@ -26,6 +26,7 @@ else()
add_subdirectory(infill) add_subdirectory(infill)
add_subdirectory(llama-bench) add_subdirectory(llama-bench)
add_subdirectory(llava) add_subdirectory(llava)
add_subdirectory(minicpmv)
if (LLAMA_SYCL) if (LLAMA_SYCL)
add_subdirectory(sycl) add_subdirectory(sycl)
endif() endif()

View File

@ -0,0 +1,42 @@
add_library(minicpmv OBJECT
minicpmv.cpp
minicpmv.h
clip.cpp
clip.h
)
target_link_libraries(minicpmv PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(minicpmv PUBLIC .)
target_include_directories(minicpmv PUBLIC ../..)
target_include_directories(minicpmv PUBLIC ../../common)
target_compile_features(minicpmv PRIVATE cxx_std_11)
add_library(minicpmv_static STATIC $<TARGET_OBJECTS:minicpmv>)
if (BUILD_SHARED_LIBS)
set_target_properties(minicpmv PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(minicpmv PRIVATE LLAMA_SHARED LLAMA_BUILD)
add_library(minicpmv_shared SHARED $<TARGET_OBJECTS:minicpmv>)
target_link_libraries(minicpmv_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS minicpmv_shared LIBRARY)
endif()
if (NOT MSVC)
target_compile_options(minicpmv PRIVATE -Wno-cast-qual) # stb_image.h
endif()
if(TARGET BUILD_INFO)
add_dependencies(minicpmv BUILD_INFO)
endif()
set(TARGET minicpmv-cli)
add_executable(minicpmv-cli minicpmv-cli.cpp)
install(TARGETS minicpmv-cli RUNTIME)
target_link_libraries(minicpmv-cli PRIVATE common minicpmv_wrapper minicpmv ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(minicpmv PRIVATE cxx_std_11)
add_library(minicpmv_wrapper OBJECT
minicpmv_wrapper.cpp
)
target_link_libraries(minicpmv_wrapper PRIVATE minicpmv ${CMAKE_THREAD_LIBS_INIT})

104
examples/minicpmv/README.md Normal file
View File

@ -0,0 +1,104 @@
## MiniCPM-Llama3-V 2.5
### Usage
Download [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5) PyTorch model from huggingface to "MiniCPM-Llama3-V-2_5" folder.
Clone llama.cpp and checkout to branch `minicpm-v2.5`:
```bash
git clone -b minicpm-v2.5 https://github.com/OpenBMB/llama.cpp.git
cd llama.cpp
```
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) by us)
```bash
python ./examples/minicpmv/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5
python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
python ./convert.py ../MiniCPM-Llama3-V-2_5/model --outtype f16 --vocab-type bpe
# quantize int4 version
./quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
```
Build for Linux or Mac
```bash
make
make minicpmv-cli
```
Inference on Linux or Mac
```
# run f16 version
./minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# run quantized int4 version
./minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# or run in interactive mode
./minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
```
### Android
#### Build on Android device using Termux
We found that build on Android device would bring better runtime performance, so we recommend to build on device.
[Termux](https://github.com/termux/termux-app#installation) is a terminal app on Android device (no root required).
Install tools in Termux:
```
apt update && apt upgrade -y
apt install git make cmake
```
It's recommended to move your model inside the `~/` directory for best performance:
```
cd storage/downloads
mv model.gguf ~/
```
#### Building the Project using Android NDK
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.
Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
```bash
mkdir build-android
cd build-android
export NDK=/your_ndk_path
cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
make
```
Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).
Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:
(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
```
$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/
$cd /data/data/com.termux/files/home/bin
$chmod +x ./*
```
Download models and push them to `/sdcard/llama.cpp/`, then move it to `/data/data/com.termux/files/home/model/`
```
$mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/
$mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/
```
Now, you can start chatting:
```
$cd /data/data/com.termux/files/home/bin
$./minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
```
### result
We use this command on Xiaomi 14 Pro, and the measured results.
```
$./minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 -t 6 --image xx.jpg -p "What is in the image?"
```
![alt text](assets/xiaomi14pro_test.jpeg)

Binary file not shown.

After

Width:  |  Height:  |  Size: 304 KiB

1870
examples/minicpmv/clip.cpp Normal file

File diff suppressed because it is too large Load Diff

85
examples/minicpmv/clip.h Normal file
View File

@ -0,0 +1,85 @@
#ifndef CLIP_H
#define CLIP_H
#include <stddef.h>
#include <stdint.h>
#include <utility>
#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef LLAMA_BUILD
# define CLIP_API __declspec(dllexport)
# else
# define CLIP_API __declspec(dllimport)
# endif
# else
# define CLIP_API __attribute__ ((visibility ("default")))
# endif
#else
# define CLIP_API
#endif
struct clip_ctx;
#ifdef __cplusplus
extern "C" {
#endif
struct clip_ctx;
struct clip_image_u8_batch {
struct clip_image_u8 * data;
size_t size;
};
struct clip_image_f32_batch {
struct clip_image_f32 * data;
size_t size;
};
CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity, std::pair<int, int> load_image_size);
CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity);
CLIP_API void clip_free(struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
// TODO: should be enum, not string
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
CLIP_API struct clip_image_f32 * clip_image_f32_init();
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
static void normalize_image_u8_to_f32(struct clip_ctx * ctx, const clip_image_u8* src, clip_image_f32* dst);
CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec, std::pair<int, int> load_image_size);
CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec, std::pair<int, int> load_image_size);
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
#ifdef __cplusplus
}
#endif
#endif // CLIP_H

View File

@ -0,0 +1,158 @@
#include "ggml.h"
#include "log.h"
#include "common.h"
#include "clip.h"
#include "minicpmv.h"
#include "minicpmv_wrapper.h"
#include "llama.h"
#include <cstdio>
#include <cstdlib>
#include <vector>
static void show_additional_info(int /*argc*/, char ** argv) {
LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
LOG_TEE(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
}
static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) {
(void) level;
(void) user_data;
LOG_TEE("%s", text);
}
struct minicpmv_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){
auto image_embed_slices = minicpmv_image_embed(params, fname);
if (!image_embed_slices[0][0]) {
std::cerr << "error: failed to load image " << fname << ". Terminating\n\n";
return NULL;
}
// process the prompt
if (params->prompt.empty() && params->interactive == false) {
LOG_TEE("prompt should be given or interactive mode should be on");
return NULL;
}
auto model = llava_init(params);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__);
return NULL;
}
const int64_t t_llava_init_start_us = ggml_time_us();
auto ctx_llava = llava_init_context(params, model);
const int64_t t_llava_init_end_us = ggml_time_us();
float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0;
LOG_TEE("\n%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms);
const int64_t t_process_image_start_us = ggml_time_us();
process_image(ctx_llava, image_embed_slices, params, n_past);
const int64_t t_process_image_end_us = ggml_time_us();
float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0;
LOG_TEE("\n%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms);
llava_image_embed_free_slice(image_embed_slices);
return ctx_llava;
}
struct llama_sampling_context * llama_init(struct minicpmv_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
std::string user_prompt = prompt;
if (!is_first) user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
// generate the response
LOG_TEE("\n");
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
return ctx_sampling;
}
const char * llama_loop(struct minicpmv_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){
const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
return tmp;
}
int main(int argc, char ** argv) {
ggml_time_init();
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
show_additional_info(argc, argv);
return 1;
}
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("llava", "log"));
LOG_TEE("Log start\n");
log_dump_cmdline(argc, argv);
llama_log_set(llama_log_callback_logTee, nullptr);
#endif // LOG_DISABLE_LOGS
if (params.mmproj.empty() || (params.image.empty())) {
gpt_params_print_usage(argc, argv, params);
show_additional_info(argc, argv);
return 1;
}
for (auto & image : params.image) {
int n_past = 0;
auto ctx_llava = minicpmv_init(&params, image, n_past);
if (!params.prompt.empty()) {
LOG_TEE("<user>%s\n", params.prompt.c_str());
LOG_TEE("<assistant>");
auto ctx_sampling = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
std::string response = "";
bool have_tmp = false;
for (int i = 0; i < max_tgt_len; i++) {
auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0){
if(!have_tmp)continue;
else break;
}
if (strstr(tmp, "###")) break; // Yi-VL behavior
have_tmp = true;
printf("%s", tmp);
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
fflush(stdout);
}
llama_sampling_free(ctx_sampling);
}else {
while (true) {
LOG_TEE("<user>");
std::string prompt;
std::getline(std::cin, prompt);
LOG_TEE("<assistant>");
auto ctx_sampling = llama_init(ctx_llava, &params, prompt, n_past, true);
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
std::string response = "";
for (int i = 0; i < max_tgt_len; i++) {
auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0) break;
if (strstr(tmp, "###")) break; // Yi-VL behavior
printf("%s", tmp);// mistral llava-1.6
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
fflush(stdout);
}
llama_sampling_free(ctx_sampling);
}
}
printf("\n");
llama_print_timings(ctx_llava->ctx_llama);
ctx_llava->model = NULL;
llava_free(ctx_llava);
}
return 0;
}

View File

@ -0,0 +1,382 @@
import argparse
import os
import json
import re
import torch
import numpy as np
from gguf import *
import timm
from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig
TEXT = "clip.text"
VISION = "clip.vision"
def k(raw_key: str, arch: str) -> str:
return raw_key.format(arch=arch)
def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool:
if name in (
"logit_scale",
"text_model.embeddings.position_ids",
"vision_model.embeddings.position_ids",
):
return True
if has_minicpmv and name in ["visual_projection.weight"]:
return True
if name.startswith("v") and not has_vision:
return True
if name.startswith("t") and not has_text:
return True
return False
def get_tensor_name(name: str) -> str:
if "projection" in name:
return name
if "mm_projector" in name:
name = name.replace("model.mm_projector", "mm")
name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
return name
return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
ap.add_argument("--text-only", action="store_true", required=False,
help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False,
help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
help="The clip model is from openclip (for ViT-SO400M type))")
ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
default_image_mean = [0.48145466, 0.4578275, 0.40821073]
default_image_std = [0.26862954, 0.26130258, 0.27577711]
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
# with proper
args = ap.parse_args()
if args.text_only and args.vision_only:
print("--text-only and --image-only arguments cannot be specified at the same time.")
exit(1)
if args.use_f32:
print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
# output in the same directory as the model if output_dir is None
dir_model = args.model_dir
if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
vocab = None
tokens = None
else:
with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
vocab = json.load(f)
tokens = [key for key in vocab]
# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]
ftype = 1
if args.use_f32:
ftype = 0
# if args.clip_model_is_vision or args.clip_model_is_openclip:
# model = CLIPVisionModel.from_pretrained(dir_model)
# processor = None
# else:
# model = CLIPModel.from_pretrained(dir_model)
# processor = CLIPProcessor.from_pretrained(dir_model)
default_vision_config = {
"hidden_size": 1152,
"image_size": 980,
"intermediate_size": 4304,
"model_type": "idefics2",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
}
vision_config = Idefics2VisionConfig(**default_vision_config)
model = Idefics2VisionTransformer(vision_config)
processor = None
# if model.attn_pool is not None:
# model.attn_pool = torch.nn.Identity()
# model.blocks = model.blocks[:-1]
model.load_state_dict(torch.load(os.path.join(dir_model, "minicpmv.clip")))
fname_middle = None
has_text_encoder = True
has_vision_encoder = True
has_minicpmv_projector = False
if args.text_only:
fname_middle = "text-"
has_vision_encoder = False
elif args.minicpmv_projector is not None:
fname_middle = "mmproj-"
has_text_encoder = False
has_minicpmv_projector = True
elif args.vision_only:
fname_middle = "vision-"
has_text_encoder = False
else:
fname_middle = ""
output_dir = args.output_dir if args.output_dir is not None else dir_model
os.makedirs(output_dir, exist_ok=True)
output_prefix = os.path.basename(output_dir).replace("ggml_", "")
fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
fout = GGUFWriter(path=fname_out, arch="clip")
fout.add_bool("clip.has_text_encoder", has_text_encoder)
fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
fout.add_bool("clip.has_minicpmv_projector", has_minicpmv_projector)
fout.add_file_type(ftype)
if args.text_only:
fout.add_description("text-only CLIP model")
elif args.vision_only and not has_minicpmv_projector:
fout.add_description("vision-only CLIP model")
elif has_minicpmv_projector:
fout.add_description("image encoder for MiniCPM-V")
# add projector type
fout.add_string("clip.projector_type", "resampler")
else:
fout.add_description("two-tower CLIP model")
if has_vision_encoder:
# vision_model hparams
fout.add_uint32("clip.vision.image_size", 448)
fout.add_uint32("clip.vision.patch_size", 14)
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), 1152)
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 4304)
fout.add_uint32("clip.vision.projection_dim", 0)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
block_count = 26
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
if processor is not None:
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
else:
image_mean = args.image_mean if args.image_mean is not None else default_image_mean
image_std = args.image_std if args.image_std is not None else default_image_std
fout.add_array("clip.vision.image_mean", image_mean)
fout.add_array("clip.vision.image_std", image_std)
use_gelu = True
fout.add_bool("clip.use_gelu", use_gelu)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_h_size, grid_w_size = grid_size, grid_size
else:
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
grid_h = np.arange(grid_h_size, dtype=np.float32)
grid_w = np.arange(grid_w_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def _replace_name_resampler(s, v):
if re.match("resampler.pos_embed", s):
return {
s: v,
re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
}
if re.match("resampler.proj", s):
return {
re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
}
if re.match("resampler.attn.in_proj_.*", s):
return {
re.sub("attn.in_proj_", "attn.q.", s): v.chunk(3, dim=0)[0],
re.sub("attn.in_proj_", "attn.k.", s): v.chunk(3, dim=0)[1],
re.sub("attn.in_proj_", "attn.v.", s): v.chunk(3, dim=0)[2],
}
return {s: v}
if has_minicpmv_projector:
projector = torch.load(args.minicpmv_projector)
new_state_dict = {}
for k, v in projector.items():
kvs = _replace_name_resampler(k, v)
for nk, nv in kvs.items():
new_state_dict[nk] = nv
projector = new_state_dict
for name, data in projector.items():
name = get_tensor_name(name)
data = data.squeeze().numpy()
n_dims = len(data.shape)
if ftype == 1:
if name[-7:] == ".weight" and n_dims == 2:
print(" Converting to float16")
data = data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
else:
if data.dtype != np.float32:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
fout.add_tensor(name, data)
print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
print("Projector tensors added\n")
def _replace_name(s, v):
s = "vision_model." + s
if re.match("vision_model.embeddings.position_embedding", s):
v = v.unsqueeze(0)
return {s: v}
return {s: v}
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items():
kvs = _replace_name(k, v)
for nk, nv in kvs.items():
new_state_dict[nk] = nv
state_dict = new_state_dict
for name, data in state_dict.items():
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector):
# we don't need this
print(f"skipping parameter: {name}")
continue
name = get_tensor_name(name)
data = data.squeeze().numpy()
n_dims = len(data.shape)
# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if n_dims == 4:
print(f"tensor {name} is always saved in f16")
data = data.astype(np.float16)
ftype_cur = 1
elif ftype == 1:
if name[-7:] == ".weight" and n_dims == 2:
print(" Converting to float16")
data = data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
else:
if data.dtype != np.float32:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
fout.add_tensor(name, data)
fout.write_header_to_file()
fout.write_kv_data_to_file()
fout.write_tensors_to_file()
fout.close()
print("Done. Output file: " + fname_out)

View File

@ -0,0 +1,48 @@
import argparse
import glob
import os
import torch
from transformers import AutoModel, AutoTokenizer
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.5 model")
args = ap.parse_args()
# find the model part that includes the the multimodal projector weights
model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
checkpoint = model.state_dict()
# get a list of mm tensor names
mm_tensors = [k for k, v in checkpoint.items() if k.startswith("resampler")]
# store these tensors in a new dictionary and torch.save them
projector = {name: checkpoint[name].float() for name in mm_tensors}
torch.save(projector, f"{args.model}/minicpmv.projector")
clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")]
if len(clip_tensors) > 0:
clip = {name.replace("vpm.", ""): checkpoint[name].float() for name in clip_tensors}
torch.save(clip, f"{args.model}/minicpmv.clip")
# added tokens should be removed to be able to convert Mistral models
if os.path.exists(f"{args.model}/added_tokens.json"):
with open(f"{args.model}/added_tokens.json", "w") as f:
f.write("{}\n")
config = model.llm.config
config._name_or_path = "openbmb/MiniCPM-Llama3-V-2.5"
config.auto_map = {
"AutoConfig": "configuration_minicpm.MiniCPMConfig",
"AutoModel": "modeling_minicpm.MiniCPMModel",
"AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
"AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
"AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
}
model.llm.save_pretrained(f"{args.model}/model")
tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
tok.save_pretrained(f"{args.model}/model")
# os.system(f"cp {args.model}/modeling_minicpm.py {args.model}/MiniCPM_l3/modeling_minicpm.py")
print("Done!")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/minicpmv.projector to prepare a minicpmv-encoder.gguf file.")

View File

@ -0,0 +1,450 @@
#include "clip.h"
#include "common.h"
#include "llama.h"
#include "minicpmv.h"
#include "base64.hpp"
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <numeric>
// RGB uint8 image
struct clip_image_u8 {
int nx;
int ny;
std::vector<uint8_t> buf;
};
// RGB float32 image (NHWC)
// Memory layout: RGBRGBRGB...
struct clip_image_f32 {
int nx;
int ny;
std::vector<float> buf;
};
struct clip_image_grid_shape {
int first;
int second;
};
static bool encode_image_with_clip_uhd(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
// std::vector<clip_image_f32*> img_res_v;
// format VectN x H x W x RGB (N x 448 x 448 x 3)
clip_image_f32 * img_res_v = clip_image_f32_init();
std::pair<int, int> load_image_size;
load_image_size.first = img->nx;
load_image_size.second = img->ny;
normalize_image_u8_to_f32(ctx_clip, img, img_res_v);
const int64_t t_img_enc_start_us = ggml_time_us();
const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
LOG_TEE("\n%s: mm_patch_merge_type is %s.\n", __func__, mm_patch_merge_type);
*n_img_pos = clip_n_patches(ctx_clip);
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res_v, image_embd, load_image_size); // image_embd shape is 96 x 4096
if (!encoded) {
LOG_TEE("Unable to encode image\n");
return false;
}
LOG_TEE("%s: image embedding created: %d tokens\n", __func__, *n_img_pos);
const int64_t t_img_enc_end_us = ggml_time_us();
float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
LOG_TEE("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos);
return true;
}
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
// make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
if (n_image_embd != n_llama_embd) {
LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
return false;
}
return true;
}
bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*6);
if (!image_embd) {
LOG_TEE("Unable to allocate memory for image embeddings\n");
return false;
}
int n_img_pos;
if (!encode_image_with_clip_uhd(ctx_clip, n_threads, img, image_embd, &n_img_pos)) {
LOG_TEE("%s: cannot encode image, aborting\n", __func__);
free(image_embd);
return false;
}
*image_embd_out = image_embd;
*n_img_pos_out = n_img_pos;
return true;
}
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
int n_eval = image_embed->n_image_pos - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
LOG_TEE("%s : failed to eval\n", __func__);
return false;
}
*n_past += n_eval;
}
return true;
}
int ensure_divide(int length, int patch_size) {
return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
}
std::pair<int, int> uhd_find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width = original_size.first;
int height = original_size.second;
if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
float r = static_cast<float>(width) / height;
height = static_cast<int>(scale_resolution / std::sqrt(r));
width = static_cast<int>(height * r);
}
int best_width = ensure_divide(width, patch_size);
int best_height = ensure_divide(height, patch_size);
return std::make_pair(best_width, best_height);
}
inline float clip(float x, float lower, float upper) {
return std::max(lower, std::min(x, upper));
}
std::pair<int, int> uhd_get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width, height;
std::tie(width, height) = original_size;
int grid_x, grid_y;
std::tie(grid_x, grid_y) = grid;
int refine_width = ensure_divide(width, grid_x);
int refine_height = ensure_divide(height, grid_y);
int grid_width = refine_width / grid_x;
int grid_height = refine_height / grid_y;
// auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line)
auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair
int best_grid_width, best_grid_height;
std::tie(best_grid_width, best_grid_height) = best_grid_size;
// std::pair<int, int> refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
std::pair<int, int> refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line)
return refine_size;
}
static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) {
const int nx = img.nx;
const int ny = img.ny;
dst.nx = target_width;
dst.ny = target_height;
dst.buf.resize(3 * target_width * target_height);
float Cc;
float C[5];
float d0, d2, d3, a0, a1, a2, a3;
int i, j, k, jj;
int x, y;
float dx, dy;
float tx, ty;
tx = (float)nx / (float)target_width;
ty = (float)ny / (float)target_height;
// Bicubic interpolation; adapted from ViT.cpp, inspired from :
// -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
// -> https://en.wikipedia.org/wiki/Bicubic_interpolation
for (i = 0; i < target_height; i++) {
for (j = 0; j < target_width; j++) {
x = (int)(tx * j);
y = (int)(ty * i);
dx = tx * j - x;
dy = ty * i - y;
for (k = 0; k < 3; k++) {
for (jj = 0; jj <= 3; jj++) {
d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2;
a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
d0 = C[0] - C[1];
d2 = C[2] - C[1];
d3 = C[3] - C[1];
a0 = C[1];
a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2;
a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
}
}
}
}
return true;
}
// inspired from LLaVA-UHD:
// -> https://arxiv.org/pdf/2403.11703
// -> https://github.com/thunlp/LLaVA-UHD
// -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14, const bool never_split=false) {
const std::pair<int, int> original_size={img->nx,img->ny};
const int original_width = img->nx;
const int original_height = img->ny;
const float log_ratio = log(1.0*original_width/original_height); //
const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
const int multiple = fmin(ceil(ratio), max_slice_nums);
std::vector<std::vector<clip_image_u8 *>> images;
LOG_TEE("%s: multiple %d\n", __func__, multiple);
images.push_back(std::vector<clip_image_u8 *>());
if(multiple <= 1){
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
clip_image_u8 *source_image = clip_image_u8_init();
bicubic_resize(*img, *source_image, best_size.first, best_size.second);
// source_image = image.resize(best_size, Image.Resampling.BICUBIC)
images[images.size()-1].push_back(source_image);
}
else if(multiple > 1){
std::vector<int> candidate_split_grids_nums;
for (int i : {multiple - 1, multiple, multiple + 1}) {
if (i == 1 || i > max_slice_nums) {
continue;
}
candidate_split_grids_nums.push_back(i);
}
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
clip_image_u8 *source_image = clip_image_u8_init();
bicubic_resize(*img, *source_image, best_size.first, best_size.second);
// source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
images[images.size()-1].push_back(source_image);
std::vector<std::pair<int, int>> candidate_grids;
for (int split_grids_nums : candidate_split_grids_nums) {
int m = 1;
while (m <= split_grids_nums) {
if (split_grids_nums % m == 0) {
candidate_grids.emplace_back(m, split_grids_nums / m);
}
++m;
}
}
std::pair<int, int> best_grid{1, 1};
float min_error = std::numeric_limits<float>::infinity();
for (const auto& grid : candidate_grids) {
float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second));
if (error < min_error) {
best_grid = grid;
min_error = error;
}
}
LOG_TEE("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second);
auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
clip_image_u8 *refine_image = clip_image_u8_init();
bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second);
LOG_TEE("%s: refine_image_size: %d %d; best_grid: %d %d\n", __func__, refine_image->nx, refine_image->ny, best_grid.first, best_grid.second);
// split_to_patches
int width = refine_image->nx;
int height = refine_image->ny;
int grid_x = int(width / best_grid.first);
int grid_y = int(height / best_grid.second);
for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
images.push_back(std::vector<clip_image_u8 *>());
for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
clip_image_u8 * patch = clip_image_u8_init();
patch->nx = grid_x;
patch->ny = grid_y;
patch->buf.resize(3 * patch->nx * patch->ny);
for (int y = patches_i; y < patches_i + grid_y; ++y) {
for (int x = patches_j; x < patches_j + grid_x; ++x) {
const int i = 3 * (y * refine_image->nx + x);
const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j));
patch->buf[j] = refine_image->buf[i];
patch->buf[j+1] = refine_image->buf[i+1];
patch->buf[j+2] = refine_image->buf[i+2];
}
}
images[images.size()-1].push_back(patch);
}
}
}
return images;
}
std::vector<std::vector<struct llava_image_embed *>> llava_image_embed_make_with_bytes_uhd(struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img) {
std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img);
for (size_t i = 0; i < imgs.size(); ++i){
for (size_t j = 0; j < imgs[i].size(); ++j) {
LOG_TEE("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny);
}
}
std::vector<std::vector<llava_image_embed *>> results;
for (size_t i = 0; i < imgs.size(); ++i){
results.push_back(std::vector<llava_image_embed *>());
for (size_t j = 0; j < imgs[i].size(); ++j) {
float* image_embed = NULL;
int n_image_pos = 0;
bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, imgs[i][j], &image_embed, &n_image_pos);
if (!image_embed_result) {
LOG_TEE("%s: coulnd't embed the image\n", __func__);
return std::vector<std::vector<struct llava_image_embed *>>();
}
auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed));
result->embed = image_embed;
result->n_image_pos = n_image_pos;
results[i].push_back(result);
}
}
return results;
}
static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) {
auto file = fopen(path, "rb");
if (file == NULL) {
LOG_TEE("%s: can't read file %s\n", __func__, path);
return false;
}
fseek(file, 0, SEEK_END);
auto fileSize = ftell(file);
fseek(file, 0, SEEK_SET);
auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data
if (buffer == NULL) {
LOG_TEE("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path);
perror("Memory allocation error");
fclose(file);
return false;
}
errno = 0;
size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
if (ferror(file)) {
die_fmt("read error: %s", strerror(errno));
}
if (ret != (size_t) fileSize) {
die("unexpectedly reached end of file");
}
fclose(file); // Close the file
*bytesOut = buffer;
*sizeOut = fileSize;
return true;
}
bool llava_image_embed_make_with_clip_img_ollama(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
auto image_embed_slices = llava_image_embed_make_with_bytes_uhd(ctx_clip, n_threads, img);
if (!image_embed_slices[0][0]){
LOG_TEE("%s: failed to embeding image\n", __func__);
return false;
}
std::string fname = "./examples/minicpm-v2.5/slice_token_for_ollama.raw";
unsigned char* slice_token;
long image_bytes_length;
auto loaded = load_file_to_bytes(fname.c_str(), &slice_token, &image_bytes_length);
if (!loaded) {
LOG_TEE("%s: failed to load %s\n", __func__, fname.c_str());
return false;
}
float * all_image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*61);
int all_n_img_pos=0;
int token_len = clip_n_mmproj_embd(ctx_clip)*sizeof(float);
std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token, token_len);
std::memcpy(all_image_embd+token_len*all_n_img_pos, image_embed_slices[0][0]->embed, 96*token_len);
all_n_img_pos+=clip_n_patches(ctx_clip);
std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len, token_len);
if (image_embed_slices.size() > 1) {
std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len*2, token_len);
for (size_t i = 1; i < image_embed_slices.size(); ++i) {
for (size_t j = 0; j < image_embed_slices[i].size(); ++j) {
std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token, token_len);
std::memcpy(all_image_embd+token_len*all_n_img_pos, image_embed_slices[i][j]->embed, 96*token_len);
all_n_img_pos+=clip_n_patches(ctx_clip);
std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len, token_len);
if (j == image_embed_slices[i].size() - 1) {
std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len*4, token_len);
}
}
}
std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len*3, token_len);
}
*image_embd_out = all_image_embd;
*n_img_pos_out = all_n_img_pos;
return true;
}
std::vector<std::vector<struct llava_image_embed *>> llava_image_embed_make_with_filename_uhd(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) {
unsigned char* image_bytes;
long image_bytes_length;
auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length);
if (!loaded) {
LOG_TEE("%s: failed to load %s\n", __func__, image_path);
return std::vector<std::vector<struct llava_image_embed *>>();
}
clip_image_u8 * img = clip_image_u8_init();
if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) {
clip_image_u8_free(img);
LOG_TEE("%s: can't load image from bytes, is it a valid image?", __func__);
return std::vector<std::vector<struct llava_image_embed *>>();
}
std::vector<std::vector<struct llava_image_embed *>> embeds = llava_image_embed_make_with_bytes_uhd(ctx_clip, n_threads, img);
clip_image_u8_free(img);
free(image_bytes);
return embeds;
}
void llava_image_embed_free_uhd(std::vector<std::vector<struct llava_image_embed *>> embed) {
for (size_t i = 0; i < embed.size(); ++i){
for (size_t j = 0; j < embed[i].size(); ++j){
free(embed[i][j]->embed);
free(embed[i][j]);
}
embed[i] = std::vector<struct llava_image_embed *>();
}
embed = std::vector<std::vector<struct llava_image_embed *>>();
}

View File

@ -0,0 +1,51 @@
#ifndef LLAVA_H
#define LLAVA_H
#include "ggml.h"
#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef LLAMA_BUILD
# define MINICPMV_API __declspec(dllexport)
# else
# define MINICPMV_API __declspec(dllimport)
# endif
# else
# define MINICPMV_API __attribute__ ((visibility ("default")))
# endif
#else
# define MINICPMV_API
#endif
struct clip_ctx;
#ifdef __cplusplus
extern "C" {
#endif
struct llava_image_embed {
float * embed;
int n_image_pos;
};
/** sanity check for clip <-> llava embed size match */
MINICPMV_API bool llava_validate_embed_size(const struct llama_context * ctx_llama, const struct clip_ctx * ctx_clip);
MINICPMV_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out);
/** build an image embed from image file bytes */
MINICPMV_API std::vector<std::vector<struct llava_image_embed *>> llava_image_embed_make_with_bytes_uhd(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
/** build an image embed from a path to an image filename */
MINICPMV_API bool llava_image_embed_make_with_clip_img_ollama(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out);
MINICPMV_API std::vector<std::vector<struct llava_image_embed *>> llava_image_embed_make_with_filename_uhd(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
MINICPMV_API void llava_image_embed_free_uhd(std::vector<std::vector<struct llava_image_embed *>> embed);
/** free an embedding made with llava_image_embed_make_* */
/** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
MINICPMV_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -0,0 +1,147 @@
#include "ggml.h"
#include "common.h"
#include "clip.h"
#include "minicpmv.h"
#include "minicpmv_wrapper.h"
#include "llama.h"
#include <cstdio>
#include <cstdlib>
#include <vector>
struct llama_model * llava_init(gpt_params * params) {
llama_backend_init();
llama_numa_init(params->numa);
llama_model_params model_params = llama_model_params_from_gpt_params(*params);
llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n" , __func__);
return NULL;
}
return model;
}
struct minicpmv_context * llava_init_context(gpt_params * params, llama_model * model) {
const char * clip_path = params->mmproj.c_str();
auto prompt = params->prompt;
if (prompt.empty()) {
prompt = "describe the image in detail.";
}
llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
if (ctx_llama == NULL) {
LOG_TEE("%s: error: failed to create the llama_context\n" , __func__);
return NULL;
}
auto ctx_llava = (struct minicpmv_context *)malloc(sizeof(minicpmv_context));
ctx_llava->ctx_llama = ctx_llama;
ctx_llava->model = model;
return ctx_llava;
}
void llava_free(struct minicpmv_context * ctx_llava) {
llama_free(ctx_llava->ctx_llama);
llama_free_model(ctx_llava->model);
llama_backend_free();
}
struct clip_ctx * clip_init_context(gpt_params * params) {
const char * clip_path = params->mmproj.c_str();
auto prompt = params->prompt;
if (prompt.empty()) {
prompt = "describe the image in detail.";
}
std::pair<int, int> load_image_size = std::make_pair(448, 448);
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1, load_image_size);
return ctx_clip;
}
std::vector<std::vector<struct llava_image_embed *>> minicpmv_image_embed(gpt_params * params, const std::string & fname){
auto ctx_clip = clip_init_context(params);
auto image_embed_and_slices = llava_image_embed_make_with_filename_slice(ctx_clip, params->n_threads, fname.c_str());
if (ctx_clip) {
clip_free(ctx_clip);
ctx_clip = NULL;
}
return image_embed_and_slices;
}
bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
int N = (int) tokens.size();
for (int i = 0; i < N; i += n_batch) {
int n_eval = (int) tokens.size() - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
LOG_TEE("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}
*n_past += n_eval;
}
return true;
}
bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
std::vector<llama_token> tokens;
tokens.push_back(id);
return eval_tokens(ctx_llama, tokens, 1, n_past);
}
bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
std::string str2 = str;
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true);
eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
return true;
}
void process_image(struct minicpmv_context * ctx_llava, std::vector<std::vector<struct llava_image_embed *>> image_embed_slices, gpt_params * params, int &n_past) {
std::string system_prompt;
system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, true);
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed_slices[0][0], params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
if (image_embed_slices.size() > 1) {
eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false);
for (size_t i = 1; i < image_embed_slices.size(); ++i) {
for (size_t j = 0; j < image_embed_slices[i].size(); ++j) {
eval_string(ctx_llava->ctx_llama, std::string("<image>").c_str(), params->n_batch, &n_past, false);
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed_slices[i][j], params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
if (j == image_embed_slices[i].size() - 1) {
eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false);
}
}
}
eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false);
}
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
}
const char * sample(struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_llama,
int * n_past) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>";
} else {
ret = llama_token_to_piece(ctx_llama, id);
}
eval_id(ctx_llama, id, n_past);
return ret.c_str();
}

View File

@ -0,0 +1,49 @@
#ifndef MINICPMV_H
#define MINICPMV_H
#include "common.h"
#include "clip.h"
#include "minicpmv.h"
#include "llama.h"
#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef LLAMA_BUILD
# define MINICPMV_API __declspec(dllexport)
# else
# define MINICPMV_API __declspec(dllimport)
# endif
# else
# define MINICPMV_API __attribute__ ((visibility ("default")))
# endif
#else
# define MINICPMV_API
#endif
#ifdef __cplusplus
extern "C" {
#endif
struct minicpmv_context {
struct llama_context * ctx_llama = NULL;
struct llama_model * model = NULL;
};
MINICPMV_API struct llama_model * llava_init(gpt_params * params);
MINICPMV_API struct minicpmv_context * llava_init_context(gpt_params * params, llama_model * model);
MINICPMV_API void llava_free(struct minicpmv_context * ctx_llava);
MINICPMV_API struct clip_ctx * clip_init_context(gpt_params * params);
MINICPMV_API std::vector<std::vector<struct llava_image_embed *>> minicpmv_image_embed(gpt_params * params, const std::string & fname);
MINICPMV_API bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past);
MINICPMV_API bool eval_id(struct llama_context * ctx_llama, int id, int * n_past);
MINICPMV_API bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos);
MINICPMV_API void process_image(struct minicpmv_context * ctx_llava, std::vector<std::vector<struct llava_image_embed *>> image_embed_slices, gpt_params * params, int &n_past);
MINICPMV_API const char * sample(struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_llama, int * n_past);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -0,0 +1,3 @@
-r ../../requirements/requirements-convert.txt
pillow~=10.2.0
torch~=2.1.1

View File

@ -645,6 +645,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
], ],
MODEL_ARCH.MINICPM: [ MODEL_ARCH.MINICPM: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_NORM,

View File

@ -1768,24 +1768,17 @@ static llama_state g_state;
// available llama models // available llama models
enum e_model { enum e_model {
MODEL_UNKNOWN, MODEL_UNKNOWN,
MODEL_14M,
MODEL_17M, MODEL_17M,
MODEL_22M, MODEL_22M,
MODEL_33M, MODEL_33M,
MODEL_70M,
MODEL_109M, MODEL_109M,
MODEL_137M, MODEL_137M,
MODEL_160M,
MODEL_335M, MODEL_335M,
MODEL_410M,
MODEL_0_5B, MODEL_0_5B,
MODEL_1B, MODEL_1B,
MODEL_1_4B,
MODEL_2B, MODEL_2B,
MODEL_2_8B,
MODEL_3B, MODEL_3B,
MODEL_4B, MODEL_4B,
MODEL_6_9B,
MODEL_7B, MODEL_7B,
MODEL_8B, MODEL_8B,
MODEL_12B, MODEL_12B,
@ -1820,7 +1813,6 @@ static const size_t GiB = 1024*MiB;
struct llama_hparams { struct llama_hparams {
bool vocab_only; bool vocab_only;
bool rope_finetuned; bool rope_finetuned;
bool use_par_res;
uint32_t n_vocab; uint32_t n_vocab;
uint32_t n_ctx_train; // context size the model was trained on uint32_t n_ctx_train; // context size the model was trained on
@ -2585,6 +2577,7 @@ static bool llama_kv_cache_init(
static bool llama_kv_cache_find_slot( static bool llama_kv_cache_find_slot(
struct llama_kv_cache & cache, struct llama_kv_cache & cache,
const struct llama_batch & batch) { const struct llama_batch & batch) {
const uint32_t n_ctx = cache.size;
const uint32_t n_tokens = batch.n_tokens; const uint32_t n_tokens = batch.n_tokens;
if (cache.recurrent) { if (cache.recurrent) {
@ -2635,16 +2628,16 @@ static bool llama_kv_cache_find_slot(
} }
// otherwise, one cell per token. // otherwise, one cell per token.
if (n_tokens > cache.size) { if (n_tokens > n_ctx) {
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
return false; return false;
} }
uint32_t n_tested = 0; uint32_t n_tested = 0;
while (true) { while (true) {
if (cache.head + n_tokens > cache.size) { if (cache.head + n_tokens > n_ctx) {
n_tested += cache.size - cache.head; n_tested += n_ctx - cache.head;
cache.head = 0; cache.head = 0;
continue; continue;
} }
@ -2663,7 +2656,7 @@ static bool llama_kv_cache_find_slot(
break; break;
} }
if (n_tested >= cache.size) { if (n_tested >= n_ctx) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return false; return false;
} }
@ -5173,19 +5166,28 @@ static bool llm_load_tensors(
case LLM_ARCH_MINICPM: case LLM_ARCH_MINICPM:
{ {
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
// output // if output is NULL, init from the input tok embed
{ if (model.output == NULL) {
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
if (model.arch != LLM_ARCH_MINICPM){ ml.n_created--; // artificial tensor
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); ml.size_data += ggml_nbytes(model.output);
// if output is NULL, init from the input tok embed
if (model.output == NULL) {
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
} }
// // output
// {
// model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
// if (model.arch != LLM_ARCH_MINICPM){
// model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
// // if output is NULL, init from the input tok embed
// if (model.output == NULL) {
// model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
// }
// }
// }
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
for (int i = 0; i < n_layer; ++i) { for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i); ggml_context * ctx_split = ctx_for_layer_split(i);
@ -10144,7 +10146,9 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// scale the input embeddings // scale the input embeddings
inpL = ggml_scale(ctx0, inpL, scale_embd); if (batch.token) {
inpL = ggml_scale(ctx0, inpL, scale_embd);
}
cb(inpL, "inp_scaled", -1); cb(inpL, "inp_scaled", -1);
// inp_pos - contains the positions // inp_pos - contains the positions
@ -10260,7 +10264,8 @@ struct llm_build_context {
cb(cur, "lmhead_scaling", -1); cb(cur, "lmhead_scaling", -1);
// lm_head // lm_head
cur = ggml_mul_mat(ctx0, model.tok_embd, cur); // cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -16216,16 +16221,33 @@ struct llama_model * llama_load_model_from_file(
} }
model->rpc_servers.push_back(servers); model->rpc_servers.push_back(servers);
} }
int status = llama_model_load(path_model, *model, params); // int status = llama_model_load(path_model, *model, params);
GGML_ASSERT(status <= 0); // GGML_ASSERT(status <= 0);
if (status < 0) { // if (status < 0) {
if (status == -1) { // if (status == -1) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); // LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
} else if (status == -2) { // } else if (status == -2) {
LLAMA_LOG_INFO("%s: cancelled model load\n", __func__); // LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
// }
// delete model;
// return nullptr;
// }
try {
int status = llama_model_load(path_model, *model, params);
GGML_ASSERT(status <= 0);
if (status < 0) {
if (status == -1) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
} else if (status == -2) {
LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
}
delete model;
return nullptr;
} }
} catch (...) {
LLAMA_LOG_ERROR("%s: exception loading model\n", __func__);
delete model; delete model;
return nullptr; throw;
} }
return model; return model;
@ -16612,6 +16634,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
// these models do not use RoPE // these models do not use RoPE
case LLM_ARCH_GPT2: case LLM_ARCH_GPT2:
case LLM_ARCH_GPTJ: case LLM_ARCH_GPTJ:
case LLM_ARCH_GPTNEOX:
case LLM_ARCH_MPT: case LLM_ARCH_MPT:
case LLM_ARCH_REFACT: case LLM_ARCH_REFACT:
case LLM_ARCH_BLOOM: case LLM_ARCH_BLOOM:
@ -16649,7 +16672,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHI3: case LLM_ARCH_PHI3:
case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA:
case LLM_ARCH_STARCODER2: case LLM_ARCH_STARCODER2:
case LLM_ARCH_GPTNEOX:
return LLAMA_ROPE_TYPE_NEOX; return LLAMA_ROPE_TYPE_NEOX;
// all model arches should be listed explicitly here // all model arches should be listed explicitly here