From 3e3357fd77bcf5bd8cfcc53ca53d4a5532e67e1b Mon Sep 17 00:00:00 2001
From: tc-mb <157115220+tc-mb@users.noreply.github.com>
Date: Wed, 22 Jan 2025 15:35:48 +0800
Subject: [PATCH 01/11] llava : support Minicpm-omni (#11289)

* init
* add readme
* update readme
* no use make
* update readme
* update fix code
* fix editorconfig-checker
* no change convert py
* use clip_image_u8_free
---
 examples/llava/README-minicpmo2.6.md          | 46 +++++++++++++++++++
 examples/llava/clip.cpp                       | 29 +++++++++++-
 examples/llava/llava.cpp                      | 13 ++----
 examples/llava/minicpmv-cli.cpp               | 10 +++-
 .../minicpmv-convert-image-encoder-to-gguf.py | 15 ++++--
 examples/llava/minicpmv-surgery.py            |  2 +-
 6 files changed, 100 insertions(+), 15 deletions(-)
 create mode 100644 examples/llava/README-minicpmo2.6.md

diff --git a/examples/llava/README-minicpmo2.6.md b/examples/llava/README-minicpmo2.6.md
new file mode 100644
index 000000000..8713a43d6
--- /dev/null
+++ b/examples/llava/README-minicpmo2.6.md
@@ -0,0 +1,46 @@
+## MiniCPM-o 2.6
+Currently, this readme covers only MiniCPM-o 2.6's image capabilities; full omni-mode support will be added as soon as possible.
+
+### Prepare models and code
+
+Download the [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from Hugging Face into the "MiniCPM-o-2_6" folder.
+
+Clone llama.cpp:
+```bash
+git clone git@github.com:OpenBMB/llama.cpp.git
+cd llama.cpp
+git checkout minicpm-omni
+```
+
+### Usage of MiniCPM-o 2.6
+
+Convert the PyTorch model to GGUF files (you can also download the pre-converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) files):
+
+```bash
+python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6
+python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4
+python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model
+
+# quantize int4 version
+./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
+```
+
+Build llama.cpp using `CMake`:
+https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md
+
+```bash
+cmake -B build
+cmake --build build --config Release
+```
+
+Inference on Linux or Mac:
+```bash
+# run f16 version
+./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+
+# run quantized int4 version
+./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+ +# or run in interactive mode +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i +``` diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 7a8a3156b..24073c5a9 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 else if (ctx->minicpmv_version == 3) { pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); } + else if (ctx->minicpmv_version == 4) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); + } ggml_set_name(pos_embed, "pos_embed"); ggml_set_input(pos_embed); } @@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 n_head = hidden_size/d_head; num_query = 64; } + else if (ctx->minicpmv_version == 4) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); @@ -2041,6 +2049,7 @@ static std::vector> uhd_slice_image(const clip_imag images[images.size()-1].push_back(patch); } } + clip_image_u8_free(refine_image); } return images; } @@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli clip_image_f32_free(res); } } + for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t j = 0; j < imgs[i].size(); ++j) { + if (imgs[i][j] != nullptr) { + clip_image_u8_free(imgs[i][j]); + } + } + } return true; } else if (ctx->has_qwen2vl_merger) { @@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i else if (ctx->minicpmv_version == 3) { n_patches = 64; } + else if (ctx->minicpmv_version == 4) { + n_patches = 64; + } } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); @@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); - int bucket_coords_h[70]; - int bucket_coords_w[70]; + int bucket_coords_h[1024]; + int bucket_coords_w[1024]; for (int i = 0; i < pos_h; i++){ bucket_coords_h[i] = std::floor(70.0*i/pos_h); } @@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima else if (ctx->minicpmv_version == 3) { embed_dim = 3584; } + else if (ctx->minicpmv_version == 4) { + embed_dim = 3584; + } auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); @@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { else if (ctx->minicpmv_version == 3) { return 3584; } + else if (ctx->minicpmv_version == 4) { + return 3584; + } } if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { return ctx->vision_model.mm_1_b->ne[0]; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index c598caf3d..2cac7933d 100644 --- a/examples/llava/llava.cpp +++ 
b/examples/llava/llava.cpp
@@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector
     return true;
 }
 
-static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) {
+static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) {
     int width = image->nx;
     int height = image->ny;
     int num_patches = (height / patch_size) * (width / patch_size);
@@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
             encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
         }
         else {
-            int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
-            if (has_minicpmv_projector == 2) {
-                encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
-            }
-            else if (has_minicpmv_projector == 3) {
-                encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
-            }
+            encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
         }
 
         if (!encoded) {
@@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
         load_image_size->height = img->ny;
         clip_add_load_image_size(ctx_clip, load_image_size);
         LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
+        delete[] img_res_v.data;
+        img_res_v.size = 0;
+        img_res_v.data = nullptr;
     }
     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
         // flat / default llava-1.5 type embedding
diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp
index 38c44e130..53d902d61 100644
--- a/examples/llava/minicpmv-cli.cpp
+++ b/examples/llava/minicpmv-cli.cpp
@@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
     else if (has_minicpmv_projector == 3) {
         system_prompt = "<|im_start|>user\n";
     }
+    else if (has_minicpmv_projector == 4) {
+        system_prompt = "<|im_start|>user\n";
+    }
     LOG_INF("%s: image token past: %d\n", __func__, n_past);
     eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
     process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
@@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
     else if (has_minicpmv_projector == 3) {
         user_prompt = "<|im_start|>user\n" + prompt;
     }
+    else if (has_minicpmv_projector == 4) {
+        user_prompt = "<|im_start|>user\n" + prompt;
+    }
     }
 
     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
@@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
     else if (has_minicpmv_projector == 3) {
         eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
     }
+    else if (has_minicpmv_projector == 4) {
+        eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
+    }
 
     // generate the response
 
@@ -308,7 +317,6 @@ int main(int argc, char ** argv) {
             const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
             response += tmp;
             if (strcmp(tmp, "</s>") == 0) break;
-            if (strstr(tmp, "###")) break; // Yi-VL behavior
             printf("%s", tmp);// mistral llava-1.6
             if (strstr(response.c_str(), "<user>")) break; // minicpm-v
             fflush(stdout);
diff --git a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py
index ea773742a..9b196757f
100644 --- a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py +++ b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py @@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073] default_image_std = [0.26862954, 0.26130258, 0.27577711] ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) -ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2) +ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2) # with proper args = ap.parse_args() @@ -545,12 +545,19 @@ if args.use_f32: minicpmv_version = args.minicpmv_version emb_dim = 4096 +block_count = 26 if minicpmv_version == 1: emb_dim = 2304 + block_count = 26 elif minicpmv_version == 2: emb_dim = 4096 + block_count = 27 elif minicpmv_version == 3: emb_dim = 3584 + block_count = 27 +elif minicpmv_version == 4: + emb_dim = 3584 + block_count = 27 default_vision_config = { "hidden_size": 1152, @@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config) if minicpmv_version == 3: vision_config = SiglipVisionConfig(**default_vision_config) model = SiglipVisionTransformer(vision_config) +elif minicpmv_version == 4: + vision_config = SiglipVisionConfig(**default_vision_config) + model = SiglipVisionTransformer(vision_config) processor = None # if model.attn_pool is not None: @@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None: fname_middle = "mmproj-" has_text_encoder = False has_minicpmv_projector = True - minicpmv_version = 3 + minicpmv_version = 4 elif args.vision_only: fname_middle = "vision-" has_text_encoder = False @@ -625,7 +635,6 @@ if has_vision_encoder: fout.add_uint32("clip.vision.projection_dim", 0) fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16) fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) - block_count = 26 fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count) if processor is not None: diff --git a/examples/llava/minicpmv-surgery.py b/examples/llava/minicpmv-surgery.py index 748ff5c57..ba8211658 100644 --- a/examples/llava/minicpmv-surgery.py +++ b/examples/llava/minicpmv-surgery.py @@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model") args = ap.parse_args() # find the model part that includes the the multimodal projector weights -model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) +model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16) checkpoint = model.state_dict() # get a list of mm tensor names From a94f3b2727e97eb6c904006eb786960c069282bc Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 22 Jan 2025 09:51:44 +0000 Subject: [PATCH 02/11] `common`: utils to split / join / repeat strings (from json converter) (#11342) * Factor string_join, string_split, string_repeat into common * json: refactor to surface a versatile builder * Update common.cpp --- common/common.cpp | 42 +++++++++++++ common/common.h | 4 ++ common/json-schema-to-grammar.cpp | 98 +++++++++++-------------------- common/json-schema-to-grammar.h | 10 +++- 4 files changed, 90 insertions(+), 64 
deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 727ab0a10..6dea8e3d2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -484,6 +484,48 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +std::string string_join(const std::vector & values, const std::string & separator) { + std::ostringstream result; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + result << separator; + } + result << values[i]; + } + return result.str(); +} + +std::vector string_split(const std::string & str, const std::string & delimiter) { + std::vector parts; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + parts.push_back(str.substr(start, end - start)); + start = end + delimiter.length(); + end = str.find(delimiter, start); + } + + parts.push_back(str.substr(start)); + + return parts; +} + +std::string string_repeat(const std::string & str, size_t n) { + if (n == 0) { + return ""; + } + + std::string result; + result.reserve(str.length() * n); + + for (size_t i = 0; i < n; ++i) { + result += str; + } + + return result; +} + std::string string_from(bool value) { return value ? "true" : "false"; } diff --git a/common/common.h b/common/common.h index 7c9d73ce1..571260372 100644 --- a/common/common.h +++ b/common/common.h @@ -429,6 +429,10 @@ std::string string_format(const char * fmt, ...); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); +std::string string_join(const std::vector & values, const std::string & separator); +std::vector string_split(const std::string & str, const std::string & delimiter); +std::string string_repeat(const std::string & str, size_t n); + void string_replace_all(std::string & s, const std::string & search, const std::string & replace); template diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index dadc18c8b..4d426b6bd 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -1,4 +1,6 @@ #include "json-schema-to-grammar.h" +#include "common.h" + #include #include #include @@ -11,11 +13,6 @@ using json = nlohmann::ordered_json; -template -static std::string join(Iterator begin, Iterator end, const std::string & separator); - -static std::string repeat(const std::string & str, size_t n); - static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { auto has_max = max_items != std::numeric_limits::max(); @@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & if (sub_len > 0) { auto from_sub = from.substr(i + 1); auto to_sub = to.substr(i + 1); - auto sub_zeros = repeat("0", sub_len); - auto sub_nines = repeat("9", sub_len); + auto sub_zeros = string_repeat("0", sub_len); + auto sub_nines = string_repeat("9", sub_len); auto to_reached = false; out << "("; @@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & auto max_digits = max_s.length(); for (auto digits = min_digits; digits < max_digits; digits++) { - uniform_range(min_s, repeat("9", digits)); - min_s = "1" + repeat("0", digits); + uniform_range(min_s, string_repeat("9", digits)); + min_s = "1" + string_repeat("0", digits); out << " | "; } uniform_range(min_s, max_s); @@ -318,49 +315,6 @@ std::unordered_map GRAMMAR_LITERAL_ESCAPES = { std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', 
'*', '+', '?'}; std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; -template -std::string join(Iterator begin, Iterator end, const std::string & separator) { - std::ostringstream result; - if (begin != end) { - result << *begin; - for (Iterator it = begin + 1; it != end; ++it) { - result << separator << *it; - } - } - return result.str(); -} - -static std::vector split(const std::string & str, const std::string & delimiter) { - std::vector tokens; - size_t start = 0; - size_t end = str.find(delimiter); - - while (end != std::string::npos) { - tokens.push_back(str.substr(start, end - start)); - start = end + delimiter.length(); - end = str.find(delimiter, start); - } - - tokens.push_back(str.substr(start)); - - return tokens; -} - -static std::string repeat(const std::string & str, size_t n) { - if (n == 0) { - return ""; - } - - std::string result; - result.reserve(str.length() * n); - - for (size_t i = 0; i < n; ++i) { - result += str; - } - - return result; -} - static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { std::smatch match; std::string result; @@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) { class SchemaConverter { private: + friend std::string build_grammar(const std::function & cb); std::function _fetch_json; bool _dotall; std::map _rules; @@ -418,7 +373,7 @@ private: for (size_t i = 0; i < alt_schemas.size(); i++) { rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); } - return join(rules.begin(), rules.end(), " | "); + return string_join(rules, " | "); } std::string _visit_pattern(const std::string & pattern, const std::string & name) { @@ -481,7 +436,7 @@ private: for (const auto & item : ret) { results.push_back(to_rule(item)); } - return std::make_pair(join(results.begin(), results.end(), " "), false); + return std::make_pair(string_join(results, " "), false); }; while (i < length) { @@ -539,7 +494,7 @@ private: } curly_brackets += '}'; i++; - auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); + auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); int min_times = 0; int max_times = std::numeric_limits::max(); try { @@ -854,7 +809,7 @@ public: return; } std::string pointer = ref.substr(ref.find('#') + 1); - std::vector tokens = split(pointer, "/"); + std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; if (target.is_null() || !target.contains(sel)) { @@ -905,7 +860,7 @@ public: for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -1019,10 +974,10 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n")); + throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { - fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", 
join(_warnings.begin(), _warnings.end(), "; ").c_str()); + fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); } } @@ -1036,10 +991,27 @@ public: }; std::string json_schema_to_grammar(const json & schema) { - SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); - auto copy = schema; - converter.resolve_refs(copy, "input"); - converter.visit(copy, ""); + return build_grammar([&](const llama_grammar_builder & callbacks) { + auto copy = schema; + callbacks.resolve_refs(copy); + callbacks.add_schema("", copy); + }); +} + +std::string build_grammar(const std::function & cb) { + SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false); + llama_grammar_builder builder { + /* .add_rule = */ [&](const std::string & name, const std::string & rule) { + return converter._add_rule(name, rule); + }, + /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) { + return converter.visit(schema, name == "root" ? "" : name); + }, + /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) { + converter.resolve_refs(schema, ""); + } + }; + cb(builder); converter.check_errors(); return converter.format_grammar(); } diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 41623b346..4f43ab3a5 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -5,4 +5,12 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); +std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); + +struct llama_grammar_builder { + std::function add_rule; + std::function add_schema; + std::function resolve_refs; +}; + +std::string build_grammar(const std::function & cb); From 96f405393461062450692430e4916809bf71c3c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ji=C5=99=C3=AD=20Podiv=C3=ADn?= <66251151+jpodivin@users.noreply.github.com> Date: Wed, 22 Jan 2025 12:51:32 +0100 Subject: [PATCH 03/11] Adding logprobs to /v1/completions (#11344) Signed-off-by: Jiri Podivin --- examples/server/server.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5f08c4ecc..412908aa8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -267,6 +267,11 @@ struct server_task { params.speculative.n_min = std::max(params.speculative.n_min, 2); params.speculative.n_max = std::max(params.speculative.n_max, 0); + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + if (data.contains("lora")) { if (data.at("lora").is_array()) { params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); From c64d2becb13f2a7c0145b516a05f028d949a046b Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 22 Jan 2025 16:16:27 +0000 Subject: [PATCH 04/11] `minja`: sync at https://github.com/google/minja/commit/0f5f7f2b3770eb682fbc11763266d45204173686 (#11352) --- common/chat-template.hpp | 37 ++++++++++++++++++++++++++++--------- common/minja.hpp | 28 ++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/common/chat-template.hpp b/common/chat-template.hpp index b4a90145c..42ee0b615 100644 --- a/common/chat-template.hpp +++ 
b/common/chat-template.hpp @@ -25,6 +25,7 @@ class chat_template { // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. // Most other templates (and OpenAI's API) expect the arguments object to be stringified. bool requires_object_arguments_ = false; + bool requires_typed_content_ = false; bool supports_system_role_ = true; bool supports_parallel_tool_calls_ = false; std::string source_; @@ -32,14 +33,14 @@ class chat_template { std::string eos_token_; std::shared_ptr template_root_; - std::string try_render( + std::string try_raw_render( const nlohmann::ordered_json & messages, const nlohmann::ordered_json & tools, bool add_generation_prompt, const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const { try { - auto prompt = apply(messages, tools, add_generation_prompt, extra_context); + auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false); // fprintf(stderr, "Prompt: %s\n", prompt.c_str()); return prompt; } catch (const std::exception & e) { @@ -60,7 +61,7 @@ class chat_template { supports_tools_ = source.find("tools") != std::string::npos; auto renders_string_arguments = - try_render({ + try_raw_render({ { {"role", "user"}, {"content", "Hey"} @@ -81,7 +82,7 @@ class chat_template { }, {}, false).find("{\"code\": \"print") != std::string::npos; if (!renders_string_arguments) { auto renders_object_arguments = - try_render({ + try_raw_render({ { {"role", "user"}, {"content", "Hey"} @@ -106,10 +107,13 @@ class chat_template { } supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos; - supports_system_role_ = try_render({ + supports_system_role_ = try_raw_render({ {{"role", "system"}, {"content", ""}}, {{"role", "user"}, {"content", "Hey"}} }, {}, false).find("") != std::string::npos; + + requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos + && try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos; } const std::string & source() const { return source_; } @@ -122,19 +126,34 @@ class chat_template { const nlohmann::ordered_json & messages, const nlohmann::ordered_json & tools, bool add_generation_prompt, - const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), + bool adjust_inputs = true) const { json actual_messages; // First, "fix" messages so they have a chance to be rendered correctly by the template - if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) { + if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) { actual_messages = json::array(); + auto add_message = [&](const json & msg) { + if (requires_typed_content_ && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { + actual_messages.push_back({ + {"role", msg.at("role")}, + {"content", {{ + {"type", "text"}, + {"text", msg.at("content")}, + }}}, + }); + } else { + actual_messages.push_back(msg); + } + }; + std::string pending_system; auto flush_sys = [&]() { if (!pending_system.empty()) { - actual_messages.push_back({ + add_message({ {"role", "user"}, {"content", pending_system}, }); @@ -217,7 +236,7 @@ class chat_template { } } } - actual_messages.push_back(message); + add_message(message); } flush_sys(); } else { diff --git 
a/common/minja.hpp b/common/minja.hpp index f0ee7a49a..80bdd4b41 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -693,7 +693,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; class TemplateToken { public: - enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter }; + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter }; static std::string typeToString(Type t) { switch (t) { @@ -712,6 +712,8 @@ public: case Type::EndMacro: return "endmacro"; case Type::Filter: return "filter"; case Type::EndFilter: return "endfilter"; + case Type::Generation: return "generation"; + case Type::EndGeneration: return "endgeneration"; } return "Unknown"; } @@ -788,6 +790,14 @@ struct EndForTemplateToken : public TemplateToken { EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} }; +struct GenerationTemplateToken : public TemplateToken { + GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {} +}; + +struct EndGenerationTemplateToken : public TemplateToken { + EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {} +}; + struct SetTemplateToken : public TemplateToken { std::string ns; std::vector var_names; @@ -2149,7 +2159,7 @@ private: static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); - static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)"); static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); @@ -2229,6 +2239,12 @@ private: } else if (keyword == "endfor") { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "generation") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "endgeneration") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); } else if (keyword == "set") { static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))"); @@ -2330,6 +2346,13 @@ private: throw unterminated(**start); } children.emplace_back(std::make_shared(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body))); + } else if (dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) { + throw unterminated(**start); + } + // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking). 
+ children.emplace_back(std::move(body)); } else if (auto text_token = dynamic_cast(token.get())) { SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep; @@ -2397,6 +2420,7 @@ private: || dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) + || dynamic_cast(token.get()) || dynamic_cast(token.get())) { it--; // unconsume the token break; // exit the loop From 12c2bdf2de34f747d13b270fc9d3b52490bf194f Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Wed, 22 Jan 2025 17:44:40 +0100 Subject: [PATCH 05/11] server : fix draft context not being released (#11354) --- examples/server/server.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 412908aa8..4cfb3c9bb 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1772,6 +1772,9 @@ struct server_context { // force F16 KV cache for the draft model for extra performance cparams_dft.type_k = GGML_TYPE_F16; cparams_dft.type_v = GGML_TYPE_F16; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); } chat_templates = common_chat_templates_from_model(model, params_base.chat_template); From 16d3df7ab0bb9def78047eab4d9b15393997f495 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 22 Jan 2025 19:44:26 +0200 Subject: [PATCH 06/11] readme : add plugin links (#11355) --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 784669ce1..97d028670 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ## Hot topics -- **Introducing GGUF-my-LoRA** https://github.com/ggerganov/llama.cpp/discussions/10123 +- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode +- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim +- Introducing GGUF-my-LoRA https://github.com/ggerganov/llama.cpp/discussions/10123 - Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669 - Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor) From 6152129d05870cb38162c422c6ba80434e021e9f Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Wed, 22 Jan 2025 19:22:20 +0100 Subject: [PATCH 07/11] main : update README documentation for batch size (#11353) * main : update README documentation for batch size * fix formatting * minor --- examples/main/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/main/README.md b/examples/main/README.md index 17d80a622..46f92eb7a 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -310,9 +310,9 @@ These options help improve the performance and memory usage of the LLaMA models. ### Batch Size -- `-b N, --batch-size N`: Set the batch size for prompt processing (default: `2048`). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations. +- `-ub N`, `--ubatch-size N`: Physical batch size. This is the maximum number of tokens that may be processed at a time. 
Increasing this value may improve performance during prompt processing, at the expense of higher memory usage. Default: `512`.
 
-- `-ub N`, `--ubatch-size N`: physical maximum batch size. This is for pipeline parallelization. Default: `512`.
+- `-b N`, `--batch-size N`: Logical batch size. Increasing this value above the value of the physical batch size may improve prompt processing performance when using multiple GPUs with pipeline parallelism. Default: `2048`.
 
 ### Prompt Caching
 
From 5245729e3317064959faefc5e5cbc63f4e9cfbea Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Thu, 23 Jan 2025 01:01:17 -0600
Subject: [PATCH 08/11] vulkan: fix diag_mask_inf (#11323)

With robustbufferaccess disabled, this shader was showing OOB stores. There
is a bounds check in the code, but the workgroup dimensions were reversed vs
CUDA and it was running the wrong number of threads. So fix the workgroup
dimensions and disable robustness for this pipeline.

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp                   | 2 +-
 ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 78c2f5c45..eae4f0739 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2012,7 +2012,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp
index 4e68742b5..26d8bc22a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp
@@ -12,7 +12,7 @@ layout (push_constant) uniform parameter
 
 #include "types.comp"
 
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

From 1971adf55e3ebbfab83771d6174d37b77c352c71 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Thu, 23 Jan 2025 01:07:50 -0600
Subject: [PATCH 09/11] vulkan: sort shaders for more deterministic binary
 (#11315)

Fixes #11306.
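
To illustrate the idea (a minimal standalone sketch, not the generator
itself; the two shader names are simply taken from the files touched above),
sorting the collected (name, filename) pairs before any output is emitted
makes the generated source independent of the order in which shader compile
jobs finish:

    #include <algorithm>
    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
        // Pretend these were collected in nondeterministic (e.g. thread
        // completion) order.
        std::vector<std::pair<std::string, std::string>> shader_fnames = {
            {"soft_max_f32", "soft_max_f32.spv"},
            {"diag_mask_inf_f32", "diag_mask_inf_f32.spv"},
        };
        // Lexicographic sort by (name, filename): the emitted declarations
        // are now byte-identical across runs.
        std::sort(shader_fnames.begin(), shader_fnames.end());
        for (const auto & p : shader_fnames) {
            printf("extern unsigned char %s_data[]; // from %s\n", p.first.c_str(), p.second.c_str());
        }
        return 0;
    }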
---
 ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 8bcb64101..e9c6cb9d4 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -17,13 +17,13 @@
 #include
 #include
 #include
+#include <algorithm>
 #include
 #include
 
 #ifdef _WIN32
     #include
     #include // For _mkdir on Windows
-    #include <algorithm> // For std::replace on w64devkit
 #else
     #include
     #include
 #endif
@@ -502,6 +502,7 @@ void write_output_files() {
     fprintf(hdr, "#include <cstdint>\n\n");
     fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
 
+    std::sort(shader_fnames.begin(), shader_fnames.end());
     for (const auto& pair : shader_fnames) {
         const std::string& name = pair.first;
 #ifdef _WIN32

From 955a6c2d91480954e90469517c4b91c4052c6a59 Mon Sep 17 00:00:00 2001
From: amd-dwang
Date: Thu, 23 Jan 2025 15:14:28 +0800
Subject: [PATCH 10/11] Vulkan-run-test: fix mmq_wg_denoms (#11343)

There was a copy-and-paste error here: *mmq_wg_denoms should be used
together with *warptile_mmq, instead of wg_denoms.

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 88 ++++++++++++++--------------
 1 file changed, 44 insertions(+), 44 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index eae4f0739..c325416d1 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1673,31 +1673,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
         if (device->coopmat_acc_f16_support) {
-            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
-            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms,
warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, 
vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. @@ -1707,31 +1707,31 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - 
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - 
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         }
     }
 #undef CREATE_MM2

From f211d1dc10332b7e89a4abd1041903fa63bfed27 Mon Sep 17 00:00:00 2001
From: Eric Curtin
Date: Thu, 23 Jan 2025 10:38:20 +0000
Subject: [PATCH 11/11] Treat hf.co/ prefix the same as hf:// (#11350)

ollama uses the hf.co/ prefix to specify Hugging Face models, just as
RamaLama uses hf://. Treat them similarly.
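
For example, after this change the following two invocations of the run
example are resolved to the same download (the model reference here is
illustrative, not a recommendation):

    llama-run hf.co/bartowski/SmolLM-135M-Instruct-GGUF
    llama-run hf://bartowski/SmolLM-135M-Instruct-GGUF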
Signed-off-by: Eric Curtin --- examples/run/run.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index e567ad716..c710d4326 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -634,20 +634,20 @@ class LlamaData { return path.substr(pos + 1); } - int remove_proto(std::string & model_) { - const std::string::size_type pos = model_.find("://"); + int rm_until_substring(std::string & model_, const std::string & substring) { + const std::string::size_type pos = model_.find(substring); if (pos == std::string::npos) { return 1; } - model_ = model_.substr(pos + 3); // Skip past "://" + model_ = model_.substr(pos + substring.size()); // Skip past the substring return 0; } int resolve_model(std::string & model_) { int ret = 0; if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) { - remove_proto(model_); + rm_until_substring(model_, "://"); return ret; } @@ -656,13 +656,16 @@ class LlamaData { const std::vector headers = { "--header", "Accept: application/vnd.docker.distribution.manifest.v2+json" }; if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) { - remove_proto(model_); + rm_until_substring(model_, "://"); + ret = huggingface_dl(model_, headers, bn); + } else if (string_starts_with(model_, "hf.co/")) { + rm_until_substring(model_, "hf.co/"); ret = huggingface_dl(model_, headers, bn); } else if (string_starts_with(model_, "ollama://")) { - remove_proto(model_); + rm_until_substring(model_, "://"); ret = ollama_dl(model_, headers, bn); } else if (string_starts_with(model_, "https://")) { - download(model_, headers, bn, true); + ret = download(model_, headers, bn, true); } else { ret = ollama_dl(model_, headers, bn); }
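
For reference, the behavior of the new rm_until_substring() helper can be
exercised in isolation; a minimal standalone sketch (the demo strings are
illustrative):

    #include <iostream>
    #include <string>

    // Same logic as the rm_until_substring() helper in the diff above:
    // strip everything up to and including the first occurrence of
    // `substring`, returning non-zero if it is not found.
    static int rm_until_substring(std::string & model_, const std::string & substring) {
        const std::string::size_type pos = model_.find(substring);
        if (pos == std::string::npos) {
            return 1;
        }
        model_ = model_.substr(pos + substring.size()); // Skip past the substring
        return 0;
    }

    int main() {
        std::string a = "hf.co/org/model";
        std::string b = "hf://org/model";
        rm_until_substring(a, "hf.co/"); // a == "org/model"
        rm_until_substring(b, "://");    // b == "org/model"
        std::cout << a << "\n" << b << "\n";
        return 0;
    }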