mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
llama : add Qwen2VL support + multimodal RoPE (#10361)
* Barebone Qwen2VL LLM convertor * Add Qwen2VL cli entrypoint * [WIP] add qwen2vl arch * Verify m-rope output * Add vl-rope/2d-rope support for qwen2vl ViT * update qwen2vl cli tool * update 5D tensor op workaround * [WIP] qwen2vl vision model * make batch and clip utils compatible with qwen2vl * [WIP] create inference workflow, gguf convert script but fix * correcting vision-rope behavior, add the missing last layer back to ViT * add arg parser to qwen2vl_surgery * replace variable size array with vector * cuda-gdb cmake preset * add fp32 mrope, vision rope kernel * add fp16 support for qwen2vl and m-rope * add `GGML_ROPE_TYPE_MROPE`, `GGML_ROPE_TYPE_VISION` * fix rope op mode switching, out dated func args * update `llama_hparams` * update to keep up stream changes * resolve linter, test errors * add makefile entry, update speical image padding token * add mrope unit test, fix few compiler warnings * rename `mrope` related function, params * minor updates on debug util, bug fixs * add `m-rope` testcase to `test-backend-ops` * Apply suggestions from code review Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * fix traililng whitespce * store `llama_hparams.rope_sections` with fixed size array * update position id tensor size check in GGML_OP_ROPE * minor updates * update `ggml_backend_*_supports_op` of unsupported backends * remote old `rope_section` compare operator --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
56eea0781c
commit
ba1cb19cdd
9
Makefile
9
Makefile
@ -22,6 +22,7 @@ BUILD_TARGETS = \
|
||||
llama-infill \
|
||||
llama-llava-cli \
|
||||
llama-minicpmv-cli\
|
||||
llama-qwen2vl-cli\
|
||||
llama-lookahead \
|
||||
llama-lookup \
|
||||
llama-lookup-create \
|
||||
@ -1404,6 +1405,14 @@ llama-minicpmv-cli: examples/llava/minicpmv-cli.cpp \
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
|
||||
|
||||
llama-qwen2vl-cli: examples/llava/qwen2vl-cli.cpp \
|
||||
examples/llava/llava.cpp \
|
||||
examples/llava/llava.h \
|
||||
examples/llava/clip.cpp \
|
||||
examples/llava/clip.h \
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
|
||||
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
swift: examples/batched.swift
|
||||
(cd examples/batched.swift; make build)
|
||||
|
@ -110,6 +110,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM)
|
||||
- [x] [Moondream](https://huggingface.co/vikhyatk/moondream2)
|
||||
- [x] [Bunny](https://github.com/BAAI-DCAI/Bunny)
|
||||
- [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
|
||||
|
||||
</details>
|
||||
|
||||
|
@ -2001,6 +2001,29 @@ class Qwen2Model(Model):
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
|
||||
|
||||
|
||||
@Model.register("Qwen2VLForConditionalGeneration")
|
||||
class Qwen2VLModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2VL
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
mrope_section = self.hparams["rope_scaling"]["mrope_section"]
|
||||
mrope_section += [0] * max(0, 4 - len(mrope_section))
|
||||
self.gguf_writer.add_rope_dimension_sections(mrope_section)
|
||||
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
for name, data in super().get_tensors():
|
||||
if name.startswith("visual."):
|
||||
continue
|
||||
yield name, data
|
||||
|
||||
|
||||
@Model.register("Qwen2MoeForCausalLM")
|
||||
class Qwen2MoeModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2MOE
|
||||
|
@ -43,3 +43,10 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-minicpmv-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-qwen2vl-cli)
|
||||
add_executable(${TARGET} qwen2vl-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-qwen2vl-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
@ -102,7 +102,9 @@ static std::string format(const char * fmt, ...) {
|
||||
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
|
||||
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
|
||||
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
||||
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
|
||||
#define KEY_USE_GELU "clip.use_gelu"
|
||||
#define KEY_USE_SILU "clip.use_silu"
|
||||
#define KEY_N_EMBD "clip.%s.embedding_length"
|
||||
#define KEY_N_FF "clip.%s.feed_forward_length"
|
||||
#define KEY_N_BLOCK "clip.%s.block_count"
|
||||
@ -129,7 +131,8 @@ static std::string format(const char * fmt, ...) {
|
||||
#define TN_TOKEN_EMBD "%s.token_embd.weight"
|
||||
#define TN_POS_EMBD "%s.position_embd.weight"
|
||||
#define TN_CLASS_EMBD "v.class_embd"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
|
||||
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
|
||||
#define TN_PATCH_BIAS "v.patch_embd.bias"
|
||||
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
|
||||
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
|
||||
@ -163,6 +166,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_LDP,
|
||||
PROJECTOR_TYPE_LDPV2,
|
||||
PROJECTOR_TYPE_RESAMPLER,
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -171,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_LDP, "ldp" },
|
||||
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
|
||||
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
};
|
||||
|
||||
|
||||
@ -463,7 +468,8 @@ struct clip_vision_model {
|
||||
|
||||
// embeddings
|
||||
struct ggml_tensor * class_embedding;
|
||||
struct ggml_tensor * patch_embeddings;
|
||||
struct ggml_tensor * patch_embeddings_0;
|
||||
struct ggml_tensor * patch_embeddings_1; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
|
||||
struct ggml_tensor * patch_bias;
|
||||
struct ggml_tensor * position_embeddings;
|
||||
|
||||
@ -553,6 +559,7 @@ struct clip_ctx {
|
||||
bool has_vision_encoder = false;
|
||||
bool has_llava_projector = false;
|
||||
bool has_minicpmv_projector = false;
|
||||
bool has_qwen2vl_merger = false;
|
||||
int minicpmv_version = 2;
|
||||
|
||||
struct clip_vision_model vision_model;
|
||||
@ -561,6 +568,7 @@ struct clip_ctx {
|
||||
float image_mean[3];
|
||||
float image_std[3];
|
||||
bool use_gelu = false;
|
||||
bool use_silu = false;
|
||||
int32_t ftype = 1;
|
||||
|
||||
bool has_class_embedding = true;
|
||||
@ -606,14 +614,26 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
image_size_height = imgs->data->ny;
|
||||
}
|
||||
}
|
||||
else if (ctx->has_qwen2vl_merger) {
|
||||
// use the image's native resolution when image is avaible
|
||||
if (is_inf) {
|
||||
// if (imgs->data->nx && imgs->data->ny) {
|
||||
image_size_width = imgs->data->nx;
|
||||
image_size_height = imgs->data->ny;
|
||||
}
|
||||
}
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||
const int patches_w = image_size_width / patch_size;
|
||||
const int patches_h = image_size_height / patch_size;
|
||||
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
|
||||
const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
|
||||
const int hidden_size = hparams.hidden_size;
|
||||
const int n_head = hparams.n_head;
|
||||
const int d_head = hidden_size / n_head;
|
||||
int n_layer = hparams.n_layer;
|
||||
const float eps = hparams.eps;
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
const int batch_size = imgs->size;
|
||||
|
||||
@ -634,10 +654,30 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
|
||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
|
||||
inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
|
||||
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
hidden_size * 2, patches_w / 2, patches_h, batch_size);
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
|
||||
inp = ggml_reshape_3d(
|
||||
ctx0, inp,
|
||||
hidden_size, patches_w * patches_h, batch_size);
|
||||
}
|
||||
else {
|
||||
inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
||||
}
|
||||
|
||||
if (ctx->has_patch_bias) {
|
||||
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
|
||||
@ -659,12 +699,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
|
||||
ggml_set_name(positions, "positions");
|
||||
ggml_set_input(positions);
|
||||
|
||||
embeddings =
|
||||
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
|
||||
if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding
|
||||
embeddings =
|
||||
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
|
||||
}
|
||||
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
int pos_w = image_size_width/patch_size;
|
||||
@ -688,7 +730,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
}
|
||||
|
||||
// loop over layers
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
|
||||
// TODO: figure out why we doing thing in this way ???
|
||||
n_layer += 1;
|
||||
}
|
||||
for (int il = 0; il < n_layer - 1; il++) {
|
||||
@ -710,8 +753,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
struct ggml_tensor * Q =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
|
||||
|
||||
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
||||
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
Q = ggml_rope_multi(
|
||||
ctx0, Q, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
}
|
||||
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
||||
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
|
||||
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
|
||||
|
||||
@ -719,6 +767,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
|
||||
|
||||
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
K = ggml_rope_multi(
|
||||
ctx0, K, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
}
|
||||
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
|
||||
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
|
||||
|
||||
@ -758,6 +811,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
|
||||
if (ctx->use_gelu) {
|
||||
cur = ggml_gelu_inplace(ctx0, cur);
|
||||
} else if (ctx->use_silu) {
|
||||
cur = ggml_silu_inplace(ctx0, cur);
|
||||
} else {
|
||||
cur = ggml_gelu_quick_inplace(ctx0, cur);
|
||||
}
|
||||
@ -769,6 +824,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
cur = ggml_add(ctx0, embeddings, cur);
|
||||
|
||||
embeddings = cur;
|
||||
|
||||
}
|
||||
|
||||
// post-layernorm
|
||||
@ -1030,6 +1086,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
||||
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||
|
||||
// GELU activation
|
||||
embeddings = ggml_gelu(ctx0, embeddings);
|
||||
|
||||
// Second linear layer
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
@ -1206,6 +1275,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
|
||||
}
|
||||
|
||||
idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER);
|
||||
if (idx != -1) {
|
||||
new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx);
|
||||
}
|
||||
// GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
|
||||
|
||||
GGML_ASSERT(new_clip->has_vision_encoder);
|
||||
@ -1214,6 +1287,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
idx = get_key_idx(ctx, KEY_USE_GELU);
|
||||
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
|
||||
|
||||
try {
|
||||
idx = get_key_idx(ctx, KEY_USE_SILU);
|
||||
new_clip->use_silu = gguf_get_val_bool(ctx, idx);
|
||||
} catch (std::runtime_error & /*e*/) {
|
||||
new_clip->use_silu = false;
|
||||
}
|
||||
|
||||
if (verbosity >= 1) {
|
||||
LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
|
||||
LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
|
||||
@ -1389,11 +1469,16 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
}
|
||||
|
||||
try {
|
||||
vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
|
||||
} catch(const std::exception& /*e*/) {
|
||||
LOG_ERR("%s: failed to load vision model tensors\n", __func__);
|
||||
}
|
||||
try {
|
||||
vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
|
||||
} catch(const std::exception& /*e*/) {
|
||||
new_clip->has_qwen2vl_merger = false;
|
||||
}
|
||||
|
||||
// LLaVA projection
|
||||
if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
@ -1481,6 +1566,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
|
||||
vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
|
||||
}
|
||||
else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
|
||||
vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
|
||||
vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
}
|
||||
else {
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
@ -1519,6 +1610,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
|
||||
clip_image_f32_batch batch;
|
||||
batch.size = 1;
|
||||
batch.data = nullptr;
|
||||
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
|
||||
ggml_gallocr_reserve(new_clip->compute_alloc, gf);
|
||||
size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
|
||||
@ -1532,6 +1624,10 @@ void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size
|
||||
ctx_clip->load_image_size = load_image_size;
|
||||
}
|
||||
|
||||
struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
|
||||
return ctx_clip->load_image_size;
|
||||
}
|
||||
|
||||
struct clip_image_size * clip_image_size_init() {
|
||||
struct clip_image_size * load_image_size = new struct clip_image_size();
|
||||
load_image_size->width = 448;
|
||||
@ -1984,6 +2080,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
}
|
||||
return true;
|
||||
}
|
||||
else if (ctx->has_qwen2vl_merger) {
|
||||
clip_image_u8 * resized = clip_image_u8_init();
|
||||
auto patch_size = clip_patch_size(ctx) * 2;
|
||||
int nx = ceil((float)img->nx / patch_size) * patch_size;
|
||||
int ny = ceil((float)img->ny / patch_size) * patch_size;
|
||||
bicubic_resize(*img, *resized, nx, ny);
|
||||
|
||||
res_imgs->data = new clip_image_f32[1];
|
||||
// clip_image_f32 * res = clip_image_f32_init();
|
||||
normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std);
|
||||
// res_imgs->data[0] = *res;
|
||||
res_imgs->size = 1;
|
||||
|
||||
// clip_image_f32_free(res);
|
||||
clip_image_u8_free(resized);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pad_to_square = true;
|
||||
if (!ctx->has_vision_encoder) {
|
||||
@ -2173,6 +2286,13 @@ size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
|
||||
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
|
||||
clip_image_f32 img;
|
||||
img.nx = img_w;
|
||||
img.ny = img_h;
|
||||
return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
int32_t clip_image_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.image_size;
|
||||
}
|
||||
@ -2194,6 +2314,13 @@ const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
|
||||
}
|
||||
|
||||
int clip_n_patches(const struct clip_ctx * ctx) {
|
||||
clip_image_f32 img;
|
||||
img.nx = ctx->vision_model.hparams.image_size;
|
||||
img.ny = ctx->vision_model.hparams.image_size;
|
||||
return clip_n_patches_by_img(ctx, &img);
|
||||
}
|
||||
|
||||
int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
|
||||
const auto & params = ctx->vision_model.hparams;
|
||||
|
||||
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
||||
@ -2207,6 +2334,11 @@ int clip_n_patches(const struct clip_ctx * ctx) {
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
n_patches = 64;
|
||||
}
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
int patch_size = params.patch_size * 2;
|
||||
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
|
||||
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
|
||||
n_patches = x_patch * y_patch;
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
@ -2335,7 +2467,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
const int image_size = hparams.image_size;
|
||||
int image_size_width = image_size;
|
||||
int image_size_height = image_size;
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
|
||||
image_size_width = imgs->data[0].nx;
|
||||
image_size_height = imgs->data[0].ny;
|
||||
}
|
||||
@ -2355,7 +2487,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
for (size_t i = 0; i < imgs->size; i++) {
|
||||
const int nx = imgs->data[i].nx;
|
||||
const int ny = imgs->data[i].ny;
|
||||
if (!ctx->has_minicpmv_projector) {
|
||||
if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
|
||||
GGML_ASSERT(nx == image_size && ny == image_size);
|
||||
}
|
||||
|
||||
@ -2413,9 +2545,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
|
||||
|
||||
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
|
||||
for(int i=0;i<pos_w * pos_h;++i){
|
||||
for(int j=0;j<embed_dim;++j){
|
||||
pos_embed_data[i*embed_dim+j]=pos_embed_t[i][j];
|
||||
for(int i=0;i < pos_w * pos_h; ++i){
|
||||
for(int j=0; j < embed_dim; ++j){
|
||||
pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
@ -2435,7 +2567,34 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
const int pw = image_size_width / patch_size;
|
||||
const int ph = image_size_height / patch_size;
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
|
||||
int ptr = 0;
|
||||
for (int y = 0; y < ph; y+=2)
|
||||
{
|
||||
for (int x = 0; x < pw; x+=2)
|
||||
{
|
||||
for (int dy = 0; dy < 2; dy++) {
|
||||
for (int dx = 0; dx < 2; dx++) {
|
||||
positions_data[ptr] = y + dy;
|
||||
positions_data[num_patches + ptr] = x + dx;
|
||||
positions_data[num_patches * 2 + ptr] = y + dy;
|
||||
positions_data[num_patches * 3 + ptr] = x + dx;
|
||||
ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
}
|
||||
else {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
@ -2444,16 +2603,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
}
|
||||
|
||||
{
|
||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||
for (int i = 0; i < num_patches; i++) {
|
||||
patches_data[i] = i + 1;
|
||||
{
|
||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||
for (int i = 0; i < num_patches; i++) {
|
||||
patches_data[i] = i + 1;
|
||||
}
|
||||
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
||||
free(patches_data);
|
||||
}
|
||||
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
||||
free(patches_data);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2626,6 +2785,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return 3584;
|
||||
}
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
}
|
||||
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
@ -2637,3 +2799,21 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
|
||||
return ctx->has_qwen2vl_merger;
|
||||
}
|
||||
|
||||
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
|
||||
clip_image_f32 clip_img;
|
||||
clip_img.buf.resize(h * w * 3);
|
||||
for (int i = 0; i < h*w*3; i++)
|
||||
{
|
||||
clip_img.buf[i] = img[i];
|
||||
}
|
||||
clip_img.nx = w;
|
||||
clip_img.ny = h;
|
||||
clip_image_encode(ctx, n_threads, &clip_img, vec);
|
||||
return true;
|
||||
}
|
||||
|
@ -45,6 +45,7 @@ CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity
|
||||
CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
|
||||
|
||||
CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
|
||||
@ -55,11 +56,13 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
|
||||
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
|
||||
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
|
||||
|
||||
CLIP_API struct clip_image_size * clip_image_size_init();
|
||||
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
|
||||
@ -86,6 +89,9 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
|
||||
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
|
||||
|
||||
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
|
||||
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -259,25 +259,33 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
|
||||
const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
|
||||
|
||||
if (clip_is_minicpmv(ctx_clip)) {
|
||||
if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) {
|
||||
std::vector<float *> image_embd_v;
|
||||
image_embd_v.resize(img_res_v.size);
|
||||
struct clip_image_size * load_image_size = clip_image_size_init();
|
||||
|
||||
for (size_t i = 0; i < img_res_v.size; i++) {
|
||||
const int64_t t_img_enc_step_start_us = ggml_time_us();
|
||||
image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
||||
image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
|
||||
int patch_size=14;
|
||||
load_image_size->width = img_res_v.data[i].nx;
|
||||
load_image_size->height = img_res_v.data[i].ny;
|
||||
clip_add_load_image_size(ctx_clip, load_image_size);
|
||||
|
||||
bool encoded = false;
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
if (clip_is_qwen2vl(ctx_clip)) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
|
||||
}
|
||||
else {
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!encoded) {
|
||||
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
|
||||
return false;
|
||||
@ -290,8 +298,11 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
|
||||
int n_img_pos_out = 0;
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
std::memcpy(image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], clip_embd_nbytes(ctx_clip));
|
||||
n_img_pos_out += clip_n_patches(ctx_clip);
|
||||
std::memcpy(
|
||||
image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
|
||||
image_embd_v[i],
|
||||
clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
|
||||
n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]);
|
||||
}
|
||||
*n_img_pos = n_img_pos_out;
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
@ -387,7 +398,13 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
|
||||
if (clip_is_minicpmv(ctx_clip)) {
|
||||
num_max_patches = 10;
|
||||
}
|
||||
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model
|
||||
float * image_embd;
|
||||
if (clip_is_qwen2vl(ctx_clip)) {
|
||||
// qwen2vl don't split image into chunks, so `num_max_patches` is not needed.
|
||||
image_embd = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img->nx, img->ny));
|
||||
} else {
|
||||
image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model
|
||||
}
|
||||
if (!image_embd) {
|
||||
LOG_ERR("Unable to allocate memory for image embeddings\n");
|
||||
return false;
|
||||
|
158
examples/llava/qwen2_vl_surgery.py
Normal file
158
examples/llava/qwen2_vl_surgery.py
Normal file
@ -0,0 +1,158 @@
|
||||
import argparse
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from gguf import *
|
||||
from transformers import (
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2VLProcessor,
|
||||
AutoProcessor,
|
||||
Qwen2VLConfig
|
||||
)
|
||||
|
||||
|
||||
VISION = "clip.vision"
|
||||
|
||||
|
||||
def k(raw_key: str, arch: str) -> str:
|
||||
return raw_key.format(arch=arch)
|
||||
|
||||
|
||||
def to_gguf_name(name: str) -> str:
|
||||
og = name
|
||||
name = name.replace("text_model", "t").replace("vision_model", "v")
|
||||
name = name.replace("blocks", "blk").replace("embeddings.", "")
|
||||
name = name.replace("attn.", "attn_")
|
||||
name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
|
||||
# name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
|
||||
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
|
||||
name = name.replace("merger.mlp", 'mm')
|
||||
print(f"[to_gguf_name] {og} --> {name}")
|
||||
return name
|
||||
|
||||
|
||||
def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
|
||||
vision_model = qwen2vl.visual
|
||||
tensor_map = {}
|
||||
for name, ten in vision_model.state_dict().items():
|
||||
ten = ten.numpy()
|
||||
if 'qkv' in name:
|
||||
if ten.ndim == 2: # weight
|
||||
c3, _ = ten.shape
|
||||
else: # bias
|
||||
c3 = ten.shape[0]
|
||||
assert c3 % 3 == 0
|
||||
c = c3 // 3
|
||||
wq = ten[:c]
|
||||
wk = ten[c: c * 2]
|
||||
wv = ten[c * 2:]
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
|
||||
elif 'merger' in name:
|
||||
if name.endswith("ln_q.weight"):
|
||||
tensor_map['v.post_ln.weight'] = ten
|
||||
elif name.endswith("ln_q.bias"):
|
||||
tensor_map['v.post_ln.bias'] = ten
|
||||
else:
|
||||
# "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
|
||||
tensor_map[to_gguf_name(name)] = ten
|
||||
elif 'patch_embed.proj.weight' in name:
|
||||
# NOTE: split Conv3D into Conv2Ds
|
||||
c1, c2, kt, kh, kw = ten.shape
|
||||
assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
|
||||
tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
|
||||
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
|
||||
else:
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
|
||||
|
||||
for new_name, ten in tensor_map.items():
|
||||
if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
|
||||
tensor_map[new_name] = ten.astype(np.float32)
|
||||
else:
|
||||
tensor_map[new_name] = ten.astype(dtype)
|
||||
tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
|
||||
return tensor_map
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.data_type == 'fp32':
|
||||
dtype = torch.float32
|
||||
np_dtype = np.float32
|
||||
ftype = 0
|
||||
elif args.data_type == 'fp16':
|
||||
dtype = torch.float32
|
||||
np_dtype = np.float16
|
||||
ftype = 1
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
model_name = args.model_name
|
||||
print("model_name: ", model_name)
|
||||
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="cpu"
|
||||
)
|
||||
cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
|
||||
vcfg = cfg.vision_config
|
||||
|
||||
if os.path.isdir(model_name):
|
||||
if model_name.endswith(os.sep):
|
||||
model_name = model_name[:-1]
|
||||
model_name = os.path.basename(model_name)
|
||||
fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"
|
||||
|
||||
fout = GGUFWriter(path=fname_out, arch="clip")
|
||||
fout.add_description("image encoder for Qwen2VL")
|
||||
|
||||
fout.add_file_type(ftype)
|
||||
fout.add_bool("clip.has_text_encoder", False)
|
||||
fout.add_bool("clip.has_vision_encoder", True)
|
||||
fout.add_bool("clip.has_qwen2vl_merger", True)
|
||||
fout.add_string("clip.projector_type", "qwen2vl_merger")
|
||||
|
||||
print(cfg.vision_config)
|
||||
if 'silu' in cfg.vision_config.hidden_act.lower():
|
||||
fout.add_bool("clip.use_silu", True)
|
||||
fout.add_bool("clip.use_gelu", False)
|
||||
elif 'gelu' in cfg.vision_config.hidden_act.lower():
|
||||
fout.add_bool("clip.use_silu", False)
|
||||
fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower())
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
tensor_map = find_vision_tensors(qwen2vl, np_dtype)
|
||||
for name, data in tensor_map.items():
|
||||
fout.add_tensor(name, data)
|
||||
|
||||
fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
|
||||
fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2)
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
|
||||
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
|
||||
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
|
||||
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
|
||||
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
|
||||
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # not sure what this does, put 0 here as a placeholder
|
||||
fout.add_name(model_name)
|
||||
"""
|
||||
HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
|
||||
it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
|
||||
"""
|
||||
|
||||
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
|
||||
fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue]
|
||||
fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]
|
||||
|
||||
fout.write_header_to_file()
|
||||
fout.write_kv_data_to_file()
|
||||
fout.write_tensors_to_file()
|
||||
fout.close()
|
||||
print("save model as: ", fname_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
|
||||
parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
581
examples/llava/qwen2vl-cli.cpp
Normal file
581
examples/llava/qwen2vl-cli.cpp
Normal file
@ -0,0 +1,581 @@
|
||||
#include "arg.h"
|
||||
#include "base64.hpp"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "clip.h"
|
||||
#include "llava.h"
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
#include "ggml-cuda.h"
|
||||
#endif
|
||||
#ifdef NDEBUG
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
#endif
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
|
||||
static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
|
||||
int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) {
|
||||
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||
const int patch_size = 14 * 2;
|
||||
const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
|
||||
const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
|
||||
auto img_tokens = image_embed->n_image_pos;
|
||||
// llama_pos mrope_pos[img_tokens * 4];
|
||||
std::vector<llama_pos> mrope_pos;
|
||||
mrope_pos.resize(img_tokens * 4);
|
||||
|
||||
for (int y = 0; y < ph; y++)
|
||||
{
|
||||
for (int x = 0; x < pw; x++)
|
||||
{
|
||||
int i = y * pw + x;
|
||||
mrope_pos[i] = *st_pos_id;
|
||||
mrope_pos[i + img_tokens] = *st_pos_id + y;
|
||||
mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
|
||||
mrope_pos[i + img_tokens * 3] = 0;
|
||||
}
|
||||
}
|
||||
*st_pos_id += std::max(pw, ph);
|
||||
|
||||
int processed = 0;
|
||||
std::vector<llama_pos> batch_mrope_pos;
|
||||
batch_mrope_pos.resize(img_tokens * 4);
|
||||
|
||||
for (int i = 0; i < img_tokens; i += n_batch) {
|
||||
int n_eval = img_tokens - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
|
||||
// llama_pos batch_mrope_pos[n_eval * 4];
|
||||
std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0);
|
||||
memcpy(batch_mrope_pos.data(), &mrope_pos[processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
|
||||
|
||||
llama_batch batch = {
|
||||
int32_t(n_eval), // n_tokens
|
||||
nullptr, // token
|
||||
(image_embed->embed+i*n_embd), // embed
|
||||
batch_mrope_pos.data(), // pos
|
||||
nullptr, // n_seq_id
|
||||
nullptr, // seq_id
|
||||
nullptr, // logits
|
||||
};
|
||||
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
*n_past += n_eval;
|
||||
processed += n_eval;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past, int * st_pos_id) {
|
||||
int N = (int) tokens.size();
|
||||
std::vector<llama_pos> pos;
|
||||
for (int i = 0; i < N; i += n_batch) {
|
||||
int n_eval = (int) tokens.size() - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
auto batch = llama_batch_get_one(&tokens[i], n_eval);
|
||||
// TODO: add mrope pos ids somewhere else
|
||||
pos.resize(batch.n_tokens * 4);
|
||||
std::fill(pos.begin(), pos.end(), 0);
|
||||
for (int j = 0; j < batch.n_tokens * 3; j ++) {
|
||||
pos[j] = *st_pos_id + (j % batch.n_tokens);
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
*n_past += n_eval;
|
||||
*st_pos_id += n_eval;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past, int * st_pos_id) {
|
||||
std::vector<llama_token> tokens;
|
||||
tokens.push_back(id);
|
||||
return eval_tokens(ctx_llama, tokens, 1, n_past, st_pos_id);
|
||||
}
|
||||
|
||||
static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, int * st_pos_id, bool add_bos){
|
||||
std::string str2 = str;
|
||||
std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
|
||||
eval_tokens(ctx_llama, embd_inp, n_batch, n_past, st_pos_id);
|
||||
return true;
|
||||
}
|
||||
|
||||
static const char * sample(struct common_sampler * smpl,
|
||||
struct llama_context * ctx_llama,
|
||||
int * n_past, int * st_pos_id) {
|
||||
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
|
||||
common_sampler_accept(smpl, id, true);
|
||||
static std::string ret;
|
||||
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
||||
ret = "</s>";
|
||||
} else {
|
||||
ret = common_token_to_piece(ctx_llama, id);
|
||||
}
|
||||
eval_id(ctx_llama, id, n_past, st_pos_id);
|
||||
return ret.c_str();
|
||||
}
|
||||
|
||||
static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
|
||||
static const char* IMG_BASE64_TAG_END = "\">";
|
||||
|
||||
static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
|
||||
begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
|
||||
end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
|
||||
}
|
||||
|
||||
static bool prompt_contains_image(const std::string& prompt) {
|
||||
size_t begin, end;
|
||||
find_image_tag_in_prompt(prompt, begin, end);
|
||||
return (begin != std::string::npos);
|
||||
}
|
||||
|
||||
// replaces the base64 image tag in the prompt with `replacement`
|
||||
static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
|
||||
size_t img_base64_str_start, img_base64_str_end;
|
||||
find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
|
||||
if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
|
||||
LOG_ERR("%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
|
||||
auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
|
||||
auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
|
||||
|
||||
auto required_bytes = base64::required_encode_size(base64_str.size());
|
||||
auto img_bytes = std::vector<unsigned char>(required_bytes);
|
||||
base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
|
||||
|
||||
auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
|
||||
if (!embed) {
|
||||
LOG_ERR("%s: could not load image from base64 string.\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
|
||||
size_t begin, end;
|
||||
find_image_tag_in_prompt(prompt, begin, end);
|
||||
if (begin == std::string::npos || end == std::string::npos) {
|
||||
return prompt;
|
||||
}
|
||||
auto pre = prompt.substr(0, begin);
|
||||
auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
|
||||
return pre + replacement + post;
|
||||
}
|
||||
|
||||
struct llava_context {
|
||||
struct clip_ctx * ctx_clip = NULL;
|
||||
struct llama_context * ctx_llama = NULL;
|
||||
struct llama_model * model = NULL;
|
||||
};
|
||||
|
||||
static void print_usage(int, char ** argv) {
|
||||
LOG("\n example usage:\n");
|
||||
LOG("\n %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
|
||||
LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
|
||||
}
|
||||
|
||||
static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) {
|
||||
|
||||
// load and preprocess the image
|
||||
llava_image_embed * embed = NULL;
|
||||
auto prompt = params->prompt;
|
||||
if (prompt_contains_image(prompt)) {
|
||||
if (!params->image.empty()) {
|
||||
LOG_INF("using base64 encoded image instead of command line image path\n");
|
||||
}
|
||||
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt);
|
||||
if (!embed) {
|
||||
LOG_ERR("%s: can't load image from prompt\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
params->prompt = remove_image_from_prompt(prompt);
|
||||
} else {
|
||||
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str());
|
||||
if (!embed) {
|
||||
fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) {
|
||||
int n_past = 0;
|
||||
int cur_pos_id = 0;
|
||||
|
||||
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
|
||||
|
||||
std::string system_prompt, user_prompt;
|
||||
size_t image_pos = prompt.find("<|vision_start|>");
|
||||
if (image_pos != std::string::npos) {
|
||||
// new templating mode: Provide the full prompt including system message and use <image> as a placeholder for the image
|
||||
system_prompt = prompt.substr(0, image_pos);
|
||||
user_prompt = prompt.substr(image_pos + std::string("<|vision_pad|>").length());
|
||||
LOG_INF("system_prompt: %s\n", system_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
LOG_INF("user_prompt: %s\n", user_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// llava-1.5 native mode
|
||||
system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|>";
|
||||
user_prompt = "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n";
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, true);
|
||||
if (image_embed != nullptr) {
|
||||
auto image_size = clip_get_load_image_size(ctx_llava->ctx_clip);
|
||||
qwen2vl_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past, &cur_pos_id, image_size);
|
||||
}
|
||||
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, false);
|
||||
|
||||
// generate the response
|
||||
|
||||
LOG("\n");
|
||||
|
||||
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
|
||||
if (!smpl) {
|
||||
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string response = "";
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past, &cur_pos_id);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
LOG("%s", tmp);
|
||||
if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
|
||||
if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
|
||||
if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
common_sampler_free(smpl);
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
static struct llama_model * llava_init(common_params * params) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params->numa);
|
||||
|
||||
llama_model_params model_params = common_model_params_to_llama(*params);
|
||||
|
||||
llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
|
||||
const char * clip_path = params->mmproj.c_str();
|
||||
|
||||
auto prompt = params->prompt;
|
||||
if (prompt.empty()) {
|
||||
prompt = "describe the image in detail.";
|
||||
}
|
||||
|
||||
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
|
||||
|
||||
|
||||
llama_context_params ctx_params = common_context_params_to_llama(*params);
|
||||
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
|
||||
|
||||
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
if (ctx_llama == NULL) {
|
||||
LOG_ERR("%s: failed to create the llama_context\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
|
||||
|
||||
ctx_llava->ctx_llama = ctx_llama;
|
||||
ctx_llava->ctx_clip = ctx_clip;
|
||||
ctx_llava->model = model;
|
||||
return ctx_llava;
|
||||
}
|
||||
|
||||
static void llava_free(struct llava_context * ctx_llava) {
|
||||
if (ctx_llava->ctx_clip) {
|
||||
clip_free(ctx_llava->ctx_clip);
|
||||
ctx_llava->ctx_clip = NULL;
|
||||
}
|
||||
|
||||
llama_free(ctx_llava->ctx_llama);
|
||||
llama_free_model(ctx_llava->model);
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
|
||||
static void debug_test_mrope_2d() {
|
||||
// 1. Initialize backend
|
||||
ggml_backend_t backend = NULL;
|
||||
std::string backend_name = "";
|
||||
#ifdef GGML_USE_CUDA
|
||||
fprintf(stderr, "%s: using CUDA backend\n", __func__);
|
||||
backend = ggml_backend_cuda_init(0); // init device 0
|
||||
backend_name = "cuda";
|
||||
if (!backend) {
|
||||
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
|
||||
}
|
||||
#endif
|
||||
// if there aren't GPU Backends fallback to CPU backend
|
||||
if (!backend) {
|
||||
backend = ggml_backend_cpu_init();
|
||||
backend_name = "cpu";
|
||||
}
|
||||
|
||||
// Calculate the size needed to allocate
|
||||
size_t ctx_size = 0;
|
||||
ctx_size += 2 * ggml_tensor_overhead(); // tensors
|
||||
// no need to allocate anything else!
|
||||
|
||||
// 2. Allocate `ggml_context` to store tensor data
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
|
||||
};
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
|
||||
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 12, 30);
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
|
||||
struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 30 * 4);
|
||||
ggml_set_name(pos, "pos");
|
||||
ggml_set_input(pos);
|
||||
|
||||
std::vector<float> dummy_q;
|
||||
dummy_q.resize(128 * 12 * 30);
|
||||
std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
|
||||
// memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));
|
||||
|
||||
std::vector<int> pos_id;
|
||||
pos_id.resize(30 * 4);
|
||||
for (int i = 0; i < 30; i ++) {
|
||||
pos_id[i] = i;
|
||||
pos_id[i + 30] = i + 10;
|
||||
pos_id[i + 60] = i + 20;
|
||||
pos_id[i + 90] = i + 30;
|
||||
}
|
||||
int sections[4] = {32, 32, 0, 0};
|
||||
|
||||
// 4. Allocate a `ggml_backend_buffer` to store all tensors
|
||||
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
||||
|
||||
// 5. Copy tensor data from main memory (RAM) to backend buffer
|
||||
ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw));
|
||||
ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos));
|
||||
|
||||
// 6. Create a `ggml_cgraph` for mul_mat operation
|
||||
struct ggml_cgraph * gf = NULL;
|
||||
struct ggml_context * ctx_cgraph = NULL;
|
||||
|
||||
// create a temporally context to build the graph
|
||||
struct ggml_init_params params0 = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
|
||||
};
|
||||
ctx_cgraph = ggml_init(params0);
|
||||
gf = ggml_new_graph(ctx_cgraph);
|
||||
|
||||
struct ggml_tensor * result0 = ggml_rope_multi(
|
||||
ctx_cgraph, inp_raw, pos, nullptr,
|
||||
128/2, sections, LLAMA_ROPE_TYPE_VISION, 32768, 1000000, 1,
|
||||
0, 1, 32, 1);
|
||||
|
||||
// Add "result" tensor and all of its dependencies to the cgraph
|
||||
ggml_build_forward_expand(gf, result0);
|
||||
|
||||
// 7. Create a `ggml_gallocr` for cgraph computation
|
||||
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
|
||||
ggml_gallocr_alloc_graph(allocr, gf);
|
||||
|
||||
// 9. Run the computation
|
||||
int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading
|
||||
if (ggml_backend_is_cpu(backend)) {
|
||||
ggml_backend_cpu_set_n_threads(backend, n_threads);
|
||||
}
|
||||
ggml_backend_graph_compute(backend, gf);
|
||||
|
||||
// 10. Retrieve results (output tensors)
|
||||
// in this example, output tensor is always the last tensor in the graph
|
||||
struct ggml_tensor * result = result0;
|
||||
// struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1];
|
||||
float * result_data = (float *)malloc(ggml_nbytes(result));
|
||||
// because the tensor data is stored in device buffer, we need to copy it back to RAM
|
||||
ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
|
||||
const std::string bin_file = "mrope_2d_" + backend_name +".bin";
|
||||
std::ofstream outFile(bin_file, std::ios::binary);
|
||||
|
||||
if (outFile.is_open()) {
|
||||
outFile.write(reinterpret_cast<const char*>(result_data), ggml_nbytes(result));
|
||||
outFile.close();
|
||||
std::cout << "Data successfully written to " + bin_file << std::endl;
|
||||
} else {
|
||||
std::cerr << "Error opening file!" << std::endl;
|
||||
}
|
||||
|
||||
free(result_data);
|
||||
// 11. Free memory and exit
|
||||
ggml_free(ctx_cgraph);
|
||||
ggml_gallocr_free(allocr);
|
||||
ggml_free(ctx);
|
||||
ggml_backend_buffer_free(buffer);
|
||||
ggml_backend_free(backend);
|
||||
}
|
||||
|
||||
static void debug_dump_img_embed(struct llava_context * ctx_llava) {
|
||||
int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
|
||||
int ne = n_embd * 4;
|
||||
float vals[56 * 56 * 3];
|
||||
// float embd[ne];
|
||||
std::vector<float> embd;
|
||||
embd.resize(ne);
|
||||
|
||||
for (int i = 0; i < 56*56; i++)
|
||||
{
|
||||
for (int c = 0; c < 3; c++)
|
||||
vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
|
||||
}
|
||||
|
||||
clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data());
|
||||
|
||||
std::ofstream outFile("img_embed.bin", std::ios::binary);
|
||||
if (outFile.is_open()) {
|
||||
outFile.write(reinterpret_cast<const char*>(embd.data()), ne * sizeof(float));
|
||||
|
||||
outFile.close();
|
||||
std::cout << "Data successfully written to mrope.bin" << std::endl;
|
||||
} else {
|
||||
std::cerr << "Error opening file!" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto * model = llava_init(¶ms);
|
||||
if (model == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (prompt_contains_image(params.prompt)) {
|
||||
auto * ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto * image_embed = load_image(ctx_llava, ¶ms, "");
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
#ifndef NDEBUG
|
||||
} else if (params.image[0].empty()) {
|
||||
auto ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
debug_test_mrope_2d();
|
||||
debug_dump_img_embed(ctx_llava);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
#endif
|
||||
} else {
|
||||
for (auto & image : params.image) {
|
||||
auto * ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto * image_embed = load_image(ctx_llava, ¶ms, image);
|
||||
if (!image_embed) {
|
||||
LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
}
|
||||
}
|
||||
|
||||
llama_free_model(model);
|
||||
|
||||
return 0;
|
||||
}
|
@ -237,7 +237,9 @@
|
||||
#define GGML_EXIT_SUCCESS 0
|
||||
#define GGML_EXIT_ABORTED 1
|
||||
|
||||
#define GGML_ROPE_TYPE_NEOX 2
|
||||
#define GGML_ROPE_TYPE_NEOX 2
|
||||
#define GGML_ROPE_TYPE_MROPE 8
|
||||
#define GGML_ROPE_TYPE_VISION 24
|
||||
|
||||
#define GGUF_MAGIC "GGUF"
|
||||
|
||||
@ -1443,6 +1445,22 @@ extern "C" {
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_rope_multi(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int sections[4],
|
||||
int mode,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
// in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
|
||||
struct ggml_context * ctx,
|
||||
|
@ -1747,6 +1747,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
if (*ext_factor != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_UPSCALE: {
|
||||
|
@ -9133,6 +9133,64 @@ static void ggml_rope_cache_init(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_mrope_cache_init(
|
||||
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
|
||||
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
||||
float * cache, float sin_sign, float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
float theta_t = theta_base_t;
|
||||
float theta_h = theta_base_h;
|
||||
float theta_w = theta_base_w;
|
||||
float theta_e = theta_base_e; // extra position id for vision encoder
|
||||
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
|
||||
int sec_w = sections[1] + sections[0];
|
||||
int sec_e = sections[2] + sec_w;
|
||||
GGML_ASSERT(sect_dims <= ne0);
|
||||
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
||||
|
||||
int sector = (i0 / 2) % sect_dims;
|
||||
if (indep_sects) {
|
||||
// compute theta independently for each dim sections
|
||||
// (i.e. reset corresponding theta when `i0` go from one section to another)
|
||||
if (sector == 0) {
|
||||
theta_t = theta_base_t;
|
||||
}
|
||||
else if (sector == sections[0]) {
|
||||
theta_h = theta_base_h;;
|
||||
}
|
||||
else if (sector == sec_w) {
|
||||
theta_w = theta_base_w;
|
||||
}
|
||||
else if (sector == sec_e) {
|
||||
theta_e = theta_base_e;
|
||||
}
|
||||
}
|
||||
|
||||
float theta = theta_t;
|
||||
if (sector >= sections[0] && sector < sec_w) {
|
||||
theta = theta_h;
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
||||
theta = theta_w;
|
||||
}
|
||||
else if (sector >= sec_w + sections[2]) {
|
||||
theta = theta_e;
|
||||
}
|
||||
|
||||
rope_yarn(
|
||||
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
||||
);
|
||||
cache[i0 + 1] *= sin_sign;
|
||||
|
||||
theta_t *= theta_scale;
|
||||
theta_w *= theta_scale;
|
||||
theta_h *= theta_scale;
|
||||
theta_e *= theta_scale;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rope_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
@ -9143,6 +9201,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
const struct ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
int sections[4];
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
@ -9156,6 +9215,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
@ -9188,6 +9248,16 @@ static void ggml_compute_forward_rope_f32(
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == ne0/2);
|
||||
}
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (src2 != NULL) {
|
||||
@ -9203,18 +9273,63 @@ static void ggml_compute_forward_rope_f32(
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
if (!is_mrope) {
|
||||
const int64_t p = pos[i2];
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
else {
|
||||
const int64_t p_t = pos[i2];
|
||||
const int64_t p_h = pos[i2 + ne2];
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
if (!is_neox) {
|
||||
if (is_neox || is_mrope) {
|
||||
if (is_vision){
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
@ -9228,8 +9343,10 @@ static void ggml_compute_forward_rope_f32(
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
@ -9239,19 +9356,20 @@ static void ggml_compute_forward_rope_f32(
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// fill the remain channels with data from src tensor
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -9269,6 +9387,7 @@ static void ggml_compute_forward_rope_f16(
|
||||
const struct ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
int sections[4];
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
@ -9281,6 +9400,8 @@ static void ggml_compute_forward_rope_f16(
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
@ -9313,6 +9434,16 @@ static void ggml_compute_forward_rope_f16(
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == ne0/2);
|
||||
}
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (src2 != NULL) {
|
||||
@ -9330,16 +9461,61 @@ static void ggml_compute_forward_rope_f16(
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
if (!is_mrope) {
|
||||
const int64_t p = pos[i2];
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
else {
|
||||
const int64_t p_t = pos[i2];
|
||||
const int64_t p_h = pos[i2 + ne2];
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
if (!is_neox) {
|
||||
if (is_neox || is_mrope) {
|
||||
if (is_vision) {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
|
||||
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
||||
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
@ -9353,8 +9529,10 @@ static void ggml_compute_forward_rope_f16(
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
@ -9364,19 +9542,19 @@ static void ggml_compute_forward_rope_f16(
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
|
||||
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,11 @@ struct rope_corr_dims {
|
||||
float v[2];
|
||||
};
|
||||
|
||||
|
||||
struct mrope_sections {
|
||||
int v[4];
|
||||
};
|
||||
|
||||
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
@ -108,6 +113,105 @@ static __global__ void rope_neox(
|
||||
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
template<typename T, bool has_ff>
|
||||
static __global__ void rope_multi(
|
||||
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int i = row*ne0 + i0/2;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||
int sec_w = sections.v[1] + sections.v[0];
|
||||
int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
const float x0 = x[i + 0];
|
||||
const float x1 = x[i + n_dims/2];
|
||||
|
||||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
template<typename T, bool has_ff>
|
||||
static __global__ void rope_vision(
|
||||
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
const int i = row*ne0 + i0/2;
|
||||
const int i2 = row/p_delta_rows; // i2-th tokens
|
||||
|
||||
int sect_dims = sections.v[0] + sections.v[1];
|
||||
int sec_w = sections.v[1] + sections.v[0];
|
||||
int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < sections.v[0]) {
|
||||
const int p = sector;
|
||||
theta_base = pos[i2]*powf(theta_scale, p);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
const int p = sector - sections.v[0];
|
||||
theta_base = pos[i2 + ne2]*powf(theta_scale, p);
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
const float x0 = x[i + 0];
|
||||
const float x1 = x[i + n_dims];
|
||||
|
||||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void rope_norm_cuda(
|
||||
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
@ -156,6 +260,56 @@ static void rope_neox_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void rope_multi_cuda(
|
||||
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections
|
||||
);
|
||||
} else {
|
||||
rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void rope_vision_cuda(
|
||||
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
|
||||
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections
|
||||
);
|
||||
} else {
|
||||
rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_norm_cuda_f16(
|
||||
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
||||
@ -185,6 +339,38 @@ static void rope_neox_cuda_f32(
|
||||
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
}
|
||||
|
||||
static void rope_multi_cuda_f16(
|
||||
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
||||
) {
|
||||
|
||||
rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
}
|
||||
|
||||
static void rope_multi_cuda_f32(
|
||||
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
||||
) {
|
||||
|
||||
rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
}
|
||||
|
||||
static void rope_vision_cuda_f16(
|
||||
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
||||
) {
|
||||
|
||||
rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
}
|
||||
|
||||
static void rope_vision_cuda_f32(
|
||||
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
||||
) {
|
||||
|
||||
rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
@ -201,8 +387,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne00 = src0->ne[0]; // head dims
|
||||
const int64_t ne01 = src0->ne[1]; // num heads
|
||||
const int64_t ne02 = src0->ne[2]; // num heads
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
@ -210,6 +397,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
mrope_sections sections;
|
||||
|
||||
// RoPE alteration for extended context
|
||||
float freq_base;
|
||||
@ -225,8 +413,19 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == ne00/2);
|
||||
}
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1_d;
|
||||
|
||||
@ -253,6 +452,34 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} else if (is_mrope && !is_vision) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_multi_cuda_f32(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, sections, stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_multi_cuda_f16(
|
||||
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, sections, stream
|
||||
);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} else if (is_vision) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_vision_cuda_f32(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, sections, stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_vision_cuda_f16(
|
||||
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, sections, stream
|
||||
);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_norm_cuda_f32(
|
||||
|
@ -1419,8 +1419,18 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_ROPE:
|
||||
return true;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
|
@ -1125,8 +1125,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_ROPE:
|
||||
return true;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_1D:
|
||||
@ -3026,7 +3036,9 @@ static void ggml_metal_encode_node(
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
GGML_ASSERT(ne10 == ne02);
|
||||
// make sure we have one or more position id(ne10) per token(ne02)
|
||||
GGML_ASSERT(ne10 % ne02 == 0);
|
||||
GGML_ASSERT(ne10 >= ne02);
|
||||
|
||||
const int nth = MIN(1024, ne00);
|
||||
|
||||
|
@ -4488,7 +4488,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_SOFT_MAX:
|
||||
return true;
|
||||
case GGML_OP_ROPE:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
{
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
}
|
||||
case GGML_OP_IM2COL:
|
||||
// TODO: add support for the new F32 operations
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
|
@ -7687,7 +7687,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_REPEAT:
|
||||
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
|
||||
case GGML_OP_ROPE:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
{
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
}
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
|
@ -3517,15 +3517,18 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
GGML_ASSERT(c->ne[0] >= n_dims / 2);
|
||||
}
|
||||
|
||||
int sections[4] = {0, 0, 0, 0};
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||
int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||
memcpy(params + 5, &freq_base, sizeof(float));
|
||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||
memcpy(params + 11, §ions, sizeof(int)*4);
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_ROPE;
|
||||
@ -3547,6 +3550,53 @@ struct ggml_tensor * ggml_rope(
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_multi(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int sections[4],
|
||||
int mode,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
// Multimodal Rotary Position Embedding
|
||||
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
||||
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
|
||||
|
||||
if (c) {
|
||||
GGML_ASSERT(c->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(c->ne[0] >= n_dims / 2);
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||
memcpy(params + 5, &freq_base, sizeof(float));
|
||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||
memcpy(¶ms[11], sections, sizeof(int)*4);
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_ROPE;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
result->src[2] = c;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
@ -131,6 +131,7 @@ class Keys:
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
|
||||
FREQ_BASE = "{arch}.rope.freq_base"
|
||||
SCALING_TYPE = "{arch}.rope.scaling.type"
|
||||
SCALING_FACTOR = "{arch}.rope.scaling.factor"
|
||||
@ -226,6 +227,7 @@ class MODEL_ARCH(IntEnum):
|
||||
QWEN = auto()
|
||||
QWEN2 = auto()
|
||||
QWEN2MOE = auto()
|
||||
QWEN2VL = auto()
|
||||
PHI2 = auto()
|
||||
PHI3 = auto()
|
||||
PLAMO = auto()
|
||||
@ -388,6 +390,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.QWEN: "qwen",
|
||||
MODEL_ARCH.QWEN2: "qwen2",
|
||||
MODEL_ARCH.QWEN2MOE: "qwen2moe",
|
||||
MODEL_ARCH.QWEN2VL: "qwen2vl",
|
||||
MODEL_ARCH.PHI2: "phi2",
|
||||
MODEL_ARCH.PHI3: "phi3",
|
||||
MODEL_ARCH.PLAMO: "plamo",
|
||||
@ -772,6 +775,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.QWEN2VL: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.QWEN2MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -751,6 +751,9 @@ class GGUFWriter:
|
||||
def add_rope_dimension_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
|
||||
self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
|
||||
|
||||
def add_rope_freq_base(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
|
||||
|
||||
|
@ -108,9 +108,11 @@ extern "C" {
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
LLAMA_ROPE_TYPE_NONE = -1,
|
||||
LLAMA_ROPE_TYPE_NORM = 0,
|
||||
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
|
||||
LLAMA_ROPE_TYPE_NONE = -1,
|
||||
LLAMA_ROPE_TYPE_NORM = 0,
|
||||
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
|
||||
LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
|
||||
LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
|
||||
};
|
||||
|
||||
enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
|
||||
|
178
src/llama.cpp
178
src/llama.cpp
@ -163,6 +163,7 @@ enum llm_arch {
|
||||
LLM_ARCH_QWEN,
|
||||
LLM_ARCH_QWEN2,
|
||||
LLM_ARCH_QWEN2MOE,
|
||||
LLM_ARCH_QWEN2VL,
|
||||
LLM_ARCH_PHI2,
|
||||
LLM_ARCH_PHI3,
|
||||
LLM_ARCH_PLAMO,
|
||||
@ -217,6 +218,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_QWEN, "qwen" },
|
||||
{ LLM_ARCH_QWEN2, "qwen2" },
|
||||
{ LLM_ARCH_QWEN2MOE, "qwen2moe" },
|
||||
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
|
||||
{ LLM_ARCH_PHI2, "phi2" },
|
||||
{ LLM_ARCH_PHI3, "phi3" },
|
||||
{ LLM_ARCH_PLAMO, "plamo" },
|
||||
@ -308,6 +310,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
LLM_KV_ROPE_SCALING_TYPE,
|
||||
@ -424,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
|
||||
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
|
||||
@ -898,6 +902,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN2VL,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN2MOE,
|
||||
{
|
||||
@ -2474,11 +2495,12 @@ struct llama_hparams {
|
||||
uint32_t time_decay_extra_dim = 0;
|
||||
uint32_t wkv_head_size = 0;
|
||||
|
||||
float rope_attn_factor = 1.0f;
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
uint32_t n_ctx_orig_yarn;
|
||||
float rope_yarn_log_mul;
|
||||
float rope_attn_factor = 1.0f;
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
uint32_t n_ctx_orig_yarn;
|
||||
float rope_yarn_log_mul;
|
||||
int rope_sections[4];
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
@ -2535,6 +2557,9 @@ struct llama_hparams {
|
||||
|
||||
if (this->rope_finetuned != other.rope_finetuned) return true;
|
||||
if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
|
||||
if (std::equal(std::begin(this->rope_sections),
|
||||
std::end(this->rope_sections),
|
||||
std::begin(other.rope_sections))) return true;
|
||||
|
||||
if (this->ssm_d_conv != other.ssm_d_conv) return true;
|
||||
if (this->ssm_d_inner != other.ssm_d_inner) return true;
|
||||
@ -3378,6 +3403,11 @@ struct llama_context {
|
||||
// whether we are computing encoder output or decoder output
|
||||
bool is_encoding = false;
|
||||
|
||||
// TODO: find a better way to accommodate mutli-dimension position encoding methods
|
||||
// number of position id each token get, 1 for each token in most cases.
|
||||
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
|
||||
int n_pos_per_token = 1;
|
||||
|
||||
// output of the encoder part of the encoder-decoder models
|
||||
std::vector<float> embd_enc;
|
||||
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
||||
@ -5747,6 +5777,13 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
{
|
||||
std::array<int, 4> section_dims;
|
||||
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, section_dims, 4, true);
|
||||
std::copy(section_dims.begin(), section_dims.begin() + 4, std::begin(hparams.rope_sections));
|
||||
}
|
||||
// fall through
|
||||
case LLM_ARCH_QWEN2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@ -8167,6 +8204,7 @@ static bool llm_load_tensors(
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN2:
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
{
|
||||
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
@ -12556,6 +12594,124 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_qwen2vl() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
|
||||
cb(lctx.inp_pos, "inp_pos", -1);
|
||||
ggml_set_input(lctx.inp_pos);
|
||||
struct ggml_tensor * inp_pos = lctx.inp_pos;
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_rope_multi(
|
||||
ctx0,
|
||||
ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_multi(
|
||||
ctx0,
|
||||
ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_qwen2moe() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
@ -16657,6 +16813,11 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_qwen2();
|
||||
} break;
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
{
|
||||
lctx.n_pos_per_token = 4;
|
||||
result = llm.build_qwen2vl();
|
||||
} break;
|
||||
case LLM_ARCH_QWEN2MOE:
|
||||
{
|
||||
result = llm.build_qwen2moe();
|
||||
@ -16875,8 +17036,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
|
||||
|
||||
if (ubatch.pos && lctx.inp_pos) {
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
|
||||
auto n_pos = lctx.n_pos_per_token;
|
||||
ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
|
||||
}
|
||||
|
||||
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
@ -20009,6 +20170,9 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_MINICPM3:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
return LLAMA_ROPE_TYPE_MROPE;
|
||||
|
||||
// all model arches should be listed explicitly here
|
||||
case LLM_ARCH_UNKNOWN:
|
||||
GGML_ABORT("unknown architecture");
|
||||
|
@ -2201,7 +2201,15 @@ struct test_rope : public test_case {
|
||||
ggml_set_name(a, "a");
|
||||
}
|
||||
|
||||
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
ggml_tensor * pos;
|
||||
if (is_mrope || is_vision) {
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
|
||||
} else {
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
|
||||
}
|
||||
ggml_set_name(pos, "pos");
|
||||
|
||||
ggml_tensor * freq = nullptr;
|
||||
@ -2210,7 +2218,20 @@ struct test_rope : public test_case {
|
||||
ggml_set_name(freq, "freq");
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||
ggml_tensor * out;
|
||||
if (is_mrope) {
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims/4 > 0);
|
||||
int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
|
||||
out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||
} else {
|
||||
GGML_ASSERT(n_dims/3 > 0);
|
||||
int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
|
||||
out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||
}
|
||||
} else {
|
||||
out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||
}
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
@ -2220,11 +2241,12 @@ struct test_rope : public test_case {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
// pos
|
||||
std::vector<int> data(ne_a[2]);
|
||||
for (int i = 0; i < ne_a[2]; i++) {
|
||||
const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
|
||||
std::vector<int> data(num_pos_ids);
|
||||
for (int i = 0; i < num_pos_ids; i++) {
|
||||
data[i] = rand() % n_ctx;
|
||||
}
|
||||
ggml_backend_tensor_set(t, data.data(), 0, ne_a[2] * sizeof(int));
|
||||
ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
|
||||
} else {
|
||||
if (t->ne[0] == n_dims/2) {
|
||||
// frequency factors in the range [0.9f, 1.1f]
|
||||
@ -3813,6 +3835,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
|
||||
}
|
||||
|
||||
if (all) {
|
||||
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 2B)
|
||||
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 7B)
|
||||
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl ViT)
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
|
||||
}
|
||||
}
|
||||
|
@ -138,7 +138,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
|
||||
struct ggml_tensor * x;
|
||||
|
||||
// rope f32
|
||||
for (int m = 0; m < 3; ++m) {
|
||||
for (int m = 0; m < 5; ++m) {
|
||||
const int ndims = 4;
|
||||
|
||||
const int64_t n_rot = 128;
|
||||
@ -147,28 +147,69 @@ int main(int /*argc*/, const char ** /*argv*/) {
|
||||
const int n_past_0 = 100;
|
||||
const int n_past_2 = 33;
|
||||
|
||||
struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
|
||||
struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
|
||||
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
|
||||
|
||||
for (int i = 0; i < ne[2]; ++i) {
|
||||
((int32_t *) p0->data)[i] = n_past_0 + i;
|
||||
((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
|
||||
((int32_t *) p2->data)[i] = n_past_2 + i;
|
||||
}
|
||||
|
||||
// test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
|
||||
const int mode = m == 0 ? 0 : m == 1 ? 2 : 4;
|
||||
|
||||
struct ggml_tensor * r0;
|
||||
struct ggml_tensor * r1;
|
||||
struct ggml_tensor * r2;
|
||||
x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||
int mode = -1;
|
||||
|
||||
// 100, 101, 102, ..., 172
|
||||
struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
|
||||
// -67, -67, -67, ..., -67
|
||||
struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
|
||||
if (m < 3) {
|
||||
struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
|
||||
struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
|
||||
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
|
||||
|
||||
// 33, 34, 35, ..., 105
|
||||
struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode);
|
||||
for (int i = 0; i < ne[2]; ++i) {
|
||||
((int32_t *) p0->data)[i] = n_past_0 + i;
|
||||
((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
|
||||
((int32_t *) p2->data)[i] = n_past_2 + i;
|
||||
}
|
||||
// test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
|
||||
mode = m == 0 ? 0 : m == 1 ? 2 : 4;
|
||||
|
||||
// 100, 101, 102, ..., 172
|
||||
r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
|
||||
// -67, -67, -67, ..., -67
|
||||
r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
|
||||
|
||||
// 33, 34, 35, ..., 105
|
||||
r2 = ggml_rope(ctx0, x, p2, n_rot, mode);
|
||||
} else {
|
||||
// testing multi-dimension rope position embedding mode
|
||||
struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
|
||||
struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
|
||||
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
|
||||
|
||||
int sections[4] = {16, 24, 24, 0};
|
||||
mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : GGML_ROPE_TYPE_VISION;
|
||||
|
||||
for (int i = 0; i < ne[2]; ++i) {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
((int32_t *) p0->data)[i + ne[2] * j] = n_past_0 + i + j;
|
||||
((int32_t *) p1->data)[i + ne[2] * j] = n_past_2 - n_past_0;
|
||||
((int32_t *) p2->data)[i + ne[2] * j] = n_past_2 + i + j;
|
||||
}
|
||||
}
|
||||
|
||||
// [[100, 101, 102, ..., 172],
|
||||
// [101, 102, 103, ..., 173],
|
||||
// [102, 103, 104, ..., 174]]
|
||||
r0 = ggml_rope_multi(
|
||||
ctx0, x, p0, nullptr,
|
||||
n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
|
||||
// [[-67, -67, -67, ..., -67]
|
||||
// [-67, -67, -67, ..., -67]
|
||||
// [-67, -67, -67, ..., -67]]
|
||||
r1 = ggml_rope_multi(
|
||||
ctx0, r0, p1, nullptr,
|
||||
n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
|
||||
|
||||
// [[33, 34, 35, ..., 105]
|
||||
// [34, 35, 36, ..., 106]
|
||||
// [35, 36, 37, ..., 107]]
|
||||
r2 = ggml_rope_multi(
|
||||
ctx0, x, p2, nullptr,
|
||||
n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
|
||||
}
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user