From 163d50adaf8897d8b734d701ff332de6be63d484 Mon Sep 17 00:00:00 2001 From: jukofyork <69222624+jukofyork@users.noreply.github.com> Date: Tue, 25 Jun 2024 21:47:40 +0100 Subject: [PATCH 1/6] fixes #7999 (adds control vectors to all `build_XXX()` functions in `llama.cpp` [needs testing] (#8060) * fixes #7999 The `build_command_r` forgot to add the control vector. * Fixes qwen2 too * Fixed all models' control vectors * Removed double calls to `cb(cur, "l_out", il)` * Moved control vector logic to llama_control_vector:apply_to() --- llama.cpp | 112 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 39 deletions(-) diff --git a/llama.cpp b/llama.cpp index 78a21008f..989c73149 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2368,13 +2368,21 @@ struct llama_control_vector { int32_t layer_start = -1; int32_t layer_end = -1; - ggml_tensor * tensor_for(int il) const { + struct ggml_tensor * tensor_for(int il) const { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { return nullptr; } return tensors[il]; } + struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { + ggml_tensor * layer_dir = tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx, cur, layer_dir); + } + return cur; + } + ~llama_control_vector() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); @@ -8023,10 +8031,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8141,6 +8146,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8245,6 +8251,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8360,9 +8367,8 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); - cur = ggml_add(ctx0, cur, inpL); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8514,10 +8520,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8648,10 +8651,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -8757,8 +8757,12 @@ struct llm_build_context { cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -8846,6 +8850,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9141,8 +9146,12 @@ struct llm_build_context { cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -9276,6 +9285,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9424,6 +9434,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9536,6 +9547,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9647,6 +9659,7 @@ struct llm_build_context { cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9792,6 +9805,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -9912,11 +9926,11 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_output); - cb(cur, "l_out", il); - cur = ggml_add(ctx0, cur, inpL); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } @@ -10048,8 +10062,10 @@ struct llm_build_context { } cur = ggml_add(ctx0, residual, cur); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } @@ -10148,9 +10164,8 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, sa_out); - cb(cur, "l_out", il); - cur = ggml_add(ctx0, cur, inpL); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10256,8 +10271,12 @@ struct llm_build_context { cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -10363,8 +10382,12 @@ struct llm_build_context { cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -10476,6 +10499,7 @@ struct llm_build_context { cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10593,6 +10617,7 @@ struct llm_build_context { cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10734,6 +10759,7 @@ struct llm_build_context { cb(cur, "hidden_scaled_ffn", -1); cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10846,6 +10872,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, sa_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -10962,7 +10989,9 @@ struct llm_build_context { NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11111,6 +11140,7 @@ struct llm_build_context { // residual cur = ggml_add(ctx0, cur, inpL); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11252,6 +11282,7 @@ struct llm_build_context { // add together residual + FFN + self-attention cur = ggml_add(ctx0, cur, inpL); cur = ggml_add(ctx0, cur, attn_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11387,10 +11418,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11504,8 +11532,12 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, inpL); cb(cur, "ffn_out", il); - inpL = ggml_add(ctx0, cur, attn_out); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, attn_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } else { // attention and ffn are computed sequentially // x = x + attn(ln1(x)) @@ -11528,8 +11560,12 @@ struct llm_build_context { LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } } @@ -11656,10 +11692,7 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_out); cb(cur, "ffn_out", il); - ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - if (layer_dir != nullptr) { - cur = ggml_add(ctx0, cur, layer_dir); - } + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer @@ -11892,6 +11925,7 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); // input for next layer From 6777c544bdd8c5d9de3220d6e2557957bbbf2a4f Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 26 Jun 2024 01:45:58 +0100 Subject: [PATCH 2/6] `json`: fix additionalProperties, allow space after enum/const (#7840) * json: default additionalProperty to true * json: don't force additional props after normal properties! * json: allow space after enum/const * json: update pydantic example to set additionalProperties: false * json: prevent additional props to redefine a typed prop * port not_strings to python, add trailing space * fix not_strings & port to js+py * Update json-schema-to-grammar.cpp * fix _not_strings for substring overlaps * json: fix additionalProperties default, uncomment tests * json: add integ. test case for additionalProperties * json: nit: simplify condition * reformat grammar integ tests w/ R"""()""" strings where there's escapes * update # tokens in server test: consts can now have trailing space --- common/json-schema-to-grammar.cpp | 99 +++++- examples/json-schema-pydantic-example.py | 6 +- examples/json_schema_to_grammar.py | 76 ++++- .../server/public/json-schema-to-grammar.mjs | 89 ++++- examples/server/tests/features/server.feature | 2 +- tests/test-grammar-integration.cpp | 320 ++++++++---------- tests/test-json-schema-to-grammar.cpp | 150 ++++++-- 7 files changed, 497 insertions(+), 245 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 07d0e952d..b40821dad 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -614,6 +614,75 @@ private: return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space"); } + /* + Returns a rule that matches a JSON string that is none of the provided strings + + not_strings({"a"}) + -> ["] ( [a] char+ | [^"a] char* )? ["] space + not_strings({"and", "also"}) + -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space + */ + std::string _not_strings(const std::vector & strings) { + + struct TrieNode { + std::map children; + bool is_end_of_string; + + TrieNode() : is_end_of_string(false) {} + + void insert(const std::string & string) { + auto node = this; + for (char c : string) { + node = &node->children[c]; + } + node->is_end_of_string = true; + } + }; + + TrieNode trie; + for (const auto & s : strings) { + trie.insert(s); + } + + std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); + std::ostringstream out; + out << "[\"] ( "; + std::function visit = [&](const TrieNode & node) { + std::ostringstream rejects; + auto first = true; + for (const auto & kv : node.children) { + rejects << kv.first; + if (first) { + first = false; + } else { + out << " | "; + } + out << "[" << kv.first << "]"; + if (!kv.second.children.empty()) { + out << " ("; + visit(kv.second); + out << ")"; + } else if (kv.second.is_end_of_string) { + out << " " << char_rule << "+"; + } + } + if (!node.children.empty()) { + if (!first) { + out << " | "; + } + out << "[^\"" << rejects.str() << "] " << char_rule << "*"; + } + }; + visit(trie); + + out << " )"; + if (!trie.is_end_of_string) { + out << "?"; + } + out << " [\"] space"; + return out.str(); + } + std::string _resolve_ref(const std::string & ref) { std::string ref_name = ref.substr(ref.find_last_of('/') + 1); if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { @@ -634,6 +703,7 @@ private: std::vector required_props; std::vector optional_props; std::unordered_map prop_kv_rule_names; + std::vector prop_names; for (const auto & kv : properties) { const auto &prop_name = kv.first; const auto &prop_schema = kv.second; @@ -648,11 +718,18 @@ private: } else { optional_props.push_back(prop_name); } + prop_names.push_back(prop_name); } - if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get())) { + if (!(additional_properties.is_boolean() && !additional_properties.get())) { std::string sub_name = name + (name.empty() ? "" : "-") + "additional"; - std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value"); - std::string kv_rule = _add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule); + std::string value_rule = + additional_properties.is_object() ? visit(additional_properties, sub_name + "-value") + : _add_primitive("value", PRIMITIVE_RULES.at("value")); + + auto key_rule = + prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string")) + : _add_rule(sub_name + "-k", _not_strings(prop_names)); + std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule); prop_kv_rule_names["*"] = kv_rule; optional_props.push_back("*"); } @@ -678,15 +755,11 @@ private: } std::string k = ks[0]; std::string kv_rule_name = prop_kv_rule_names[k]; - if (k == "*") { - res = _add_rule( - name + (name.empty() ? "" : "-") + "additional-kvs", - kv_rule_name + " ( \",\" space " + kv_rule_name + " )*" - ); - } else if (first_is_optional) { - res = "( \",\" space " + kv_rule_name + " )?"; + std::string comma_ref = "( \",\" space " + kv_rule_name + " )"; + if (first_is_optional) { + res = comma_ref + (k == "*" ? "*" : "?"); } else { - res = kv_rule_name; + res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : ""); } if (ks.size() > 1) { res += " " + _add_rule( @@ -824,13 +897,13 @@ public: } return _add_rule(rule_name, _generate_union_rule(name, schema_types)); } else if (schema.contains("const")) { - return _add_rule(rule_name, _generate_constant_rule(schema["const"])); + return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space"); } else if (schema.contains("enum")) { std::vector enum_values; for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | ")); + return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { diff --git a/examples/json-schema-pydantic-example.py b/examples/json-schema-pydantic-example.py index 2240188cd..2a24f8118 100644 --- a/examples/json-schema-pydantic-example.py +++ b/examples/json-schema-pydantic-example.py @@ -3,7 +3,7 @@ #! pip install pydantic #! python json-schema-pydantic-example.py -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel, Extra, TypeAdapter from annotated_types import MinLen from typing import Annotated, List, Optional import json, requests @@ -50,12 +50,16 @@ else: if __name__ == '__main__': class QAPair(BaseModel): + class Config: + extra = 'forbid' # triggers additionalProperties: false in the JSON schema question: str concise_answer: str justification: str stars: Annotated[int, Field(ge=1, le=5)] class PyramidalSummary(BaseModel): + class Config: + extra = 'forbid' # triggers additionalProperties: false in the JSON schema title: str summary: str question_answers: Annotated[List[QAPair], MinLen(2)] diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 86500a8c3..3f3132f88 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -4,8 +4,7 @@ import itertools import json import re import sys -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union - +from typing import Any, List, Optional, Set, Tuple, Union def _build_repetition(item_rule, min_items, max_items, separator_rule=None): @@ -276,6 +275,51 @@ class SchemaConverter: return ''.join(('(', *recurse(0), ')')) + def _not_strings(self, strings): + class TrieNode: + def __init__(self): + self.children = {} + self.is_end_of_string = False + + def insert(self, string): + node = self + for c in string: + node = node.children.setdefault(c, TrieNode()) + node.is_end_of_string = True + + trie = TrieNode() + for s in strings: + trie.insert(s) + + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + out = ['["] ( '] + + def visit(node): + rejects = [] + first = True + for c in sorted(node.children.keys()): + child = node.children[c] + rejects.append(c) + if first: + first = False + else: + out.append(' | ') + out.append(f'[{c}]') + if child.children: + out.append(f' (') + visit(child) + out.append(')') + elif child.is_end_of_string: + out.append(f' {char_rule}+') + if node.children: + if not first: + out.append(' | ') + out.append(f'[^"{"".join(rejects)}] {char_rule}*') + visit(trie) + + out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space') + return ''.join(out) + def _add_rule(self, name, rule): esc_name = INVALID_RULE_CHARS_RE.sub('-', name) if esc_name not in self._rules or self._rules[esc_name] == rule: @@ -524,10 +568,10 @@ class SchemaConverter: return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) elif 'const' in schema: - return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) + return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space') elif 'enum' in schema: - rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space' return self._add_rule(rule_name, rule) elif schema_type in (None, 'object') and \ @@ -632,7 +676,7 @@ class SchemaConverter: self._add_primitive(dep, dep_rule) return n - def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): + def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]): prop_order = self._prop_order # sort by position in prop_order (if specified) then by original order sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] @@ -647,12 +691,16 @@ class SchemaConverter: required_props = [k for k in sorted_props if k in required] optional_props = [k for k in sorted_props if k not in required] - if additional_properties == True or isinstance(additional_properties, dict): + if additional_properties != False: sub_name = f'{name}{"-" if name else ""}additional' - value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') + value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \ + self._add_primitive('value', PRIMITIVE_RULES['value']) + key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \ + else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props)) + prop_kv_rule_names["*"] = self._add_rule( f'{sub_name}-kv', - self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' + f'{key_rule} ":" space {value_rule}' ) optional_props.append("*") @@ -667,15 +715,11 @@ class SchemaConverter: def get_recursive_refs(ks, first_is_optional): [k, *rest] = ks kv_rule_name = prop_kv_rule_names[k] - if k == '*': - res = self._add_rule( - f'{name}{"-" if name else ""}additional-kvs', - f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*' - ) - elif first_is_optional: - res = f'( "," space {kv_rule_name} )?' + comma_ref = f'( "," space {kv_rule_name} )' + if first_is_optional: + res = comma_ref + ('*' if k == '*' else '?') else: - res = kv_rule_name + res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '') if len(rest) > 0: res += ' ' + self._add_rule( f'{name}{"-" if name else ""}{k}-rest', diff --git a/examples/server/public/json-schema-to-grammar.mjs b/examples/server/public/json-schema-to-grammar.mjs index f340f94bd..02015bbd4 100644 --- a/examples/server/public/json-schema-to-grammar.mjs +++ b/examples/server/public/json-schema-to-grammar.mjs @@ -532,6 +532,64 @@ export class SchemaConverter { return this._addRule(name, "\"\\\"\" " + toRule(transform()) + " \"\\\"\" space") } + _notStrings(strings) { + class TrieNode { + constructor() { + this.children = {}; + this.isEndOfString = false; + } + + insert(str) { + let node = this; + for (const c of str) { + node = node.children[c] = node.children[c] || new TrieNode(); + } + node.isEndOfString = true; + } + } + + const trie = new TrieNode(); + for (const s of strings) { + trie.insert(s); + } + + const charRuleName = this._addPrimitive('char', PRIMITIVE_RULES['char']); + const out = ['["] ( ']; + + const visit = (node) => { + const rejects = []; + let first = true; + for (const c of Object.keys(node.children).sort()) { + const child = node.children[c]; + rejects.push(c); + if (first) { + first = false; + } else { + out.push(' | '); + } + out.push(`[${c}]`); + if (Object.keys(child.children).length > 0) { + out.push(' ('); + visit(child); + out.push(')'); + } else if (child.isEndOfString) { + out.push(` ${charRuleName}+`); + } + } + if (Object.keys(node.children).length > 0) { + if (!first) { + out.push(' | '); + } + out.push(`[^"${rejects.join('')}] ${charRuleName}*`); + } + }; + + visit(trie); + + out.push(` )${trie.isEndOfString ? '' : '?'} ["] space`); + return out.join(''); + } + _resolveRef(ref) { let refName = ref.split('/').pop(); if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) { @@ -560,9 +618,9 @@ export class SchemaConverter { } else if (Array.isArray(schemaType)) { return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t })))); } else if ('const' in schema) { - return this._addRule(ruleName, this._generateConstantRule(schema.const)); + return this._addRule(ruleName, this._generateConstantRule(schema.const) + ' space'); } else if ('enum' in schema) { - const rule = schema.enum.map(v => this._generateConstantRule(v)).join(' | '); + const rule = '(' + schema.enum.map(v => this._generateConstantRule(v)).join(' | ') + ') space'; return this._addRule(ruleName, rule); } else if ((schemaType === undefined || schemaType === 'object') && ('properties' in schema || @@ -599,7 +657,7 @@ export class SchemaConverter { } } - return this._addRule(ruleName, this._buildObjectRule(properties, required, name, /* additionalProperties= */ false)); + return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null)); } else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) { const items = schema.items ?? schema.prefixItems; if (Array.isArray(items)) { @@ -693,12 +751,19 @@ export class SchemaConverter { const requiredProps = sortedProps.filter(k => required.has(k)); const optionalProps = sortedProps.filter(k => !required.has(k)); - if (typeof additionalProperties === 'object' || additionalProperties === true) { + if (additionalProperties !== false) { const subName = `${name ?? ''}${name ? '-' : ''}additional`; - const valueRule = this.visit(additionalProperties === true ? {} : additionalProperties, `${subName}-value`); + const valueRule = + additionalProperties != null && typeof additionalProperties === 'object' ? this.visit(additionalProperties, `${subName}-value`) + : this._addPrimitive('value', PRIMITIVE_RULES['value']); + + const key_rule = + sortedProps.length === 0 ? this._addPrimitive('string', PRIMITIVE_RULES['string']) + : this._addRule(`${subName}-k`, this._notStrings(sortedProps)); + propKvRuleNames['*'] = this._addRule( `${subName}-kv`, - `${this._addPrimitive('string', PRIMITIVE_RULES['string'])} ":" space ${valueRule}`); + `${key_rule} ":" space ${valueRule}`); optionalProps.push('*'); } @@ -715,15 +780,11 @@ export class SchemaConverter { const [k, ...rest] = ks; const kvRuleName = propKvRuleNames[k]; let res; - if (k === '*') { - res = this._addRule( - `${name ?? ''}${name ? '-' : ''}additional-kvs`, - `${kvRuleName} ( "," space ` + kvRuleName + ` )*` - ) - } else if (firstIsOptional) { - res = `( "," space ${kvRuleName} )?`; + const commaRef = `( "," space ${kvRuleName} )`; + if (firstIsOptional) { + res = commaRef + (k === '*' ? '*' : '?'); } else { - res = kvRuleName; + res = kvRuleName + (k === '*' ? ' ' + commaRef + '*' : ''); } if (rest.length > 0) { res += ' ' + this._addRule( diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index d21c09135..b55971454 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -82,7 +82,7 @@ Feature: llama.cpp server Examples: Prompts | response_format | n_predicted | re_content | - | {"type": "json_object", "schema": {"const": "42"}} | 5 | "42" | + | {"type": "json_object", "schema": {"const": "42"}} | 6 | "42" | | {"type": "json_object", "schema": {"items": [{"type": "integer"}]}} | 10 | \[ -300 \] | | {"type": "json_object"} | 10 | \{ " Jacky. | diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 5750d362a..23ef8324c 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -15,8 +15,6 @@ using json = nlohmann::ordered_json; -//#define INCLUDE_FAILING_TESTS 1 - static llama_grammar* build_grammar(const std::string & grammar_str) { auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); @@ -754,7 +752,7 @@ static void test_json_schema() { )""", // Passing strings { - "{}", + R"""({})""", R"""({"foo": "bar"})""", }, // Failing strings @@ -762,7 +760,7 @@ static void test_json_schema() { "", "[]", "null", - "\"\"", + R"""("")""", "true", } ); @@ -770,16 +768,14 @@ static void test_json_schema() { test_schema( "exotic formats (list)", // Schema - R"""( - { + R"""({ "items": [ { "format": "date" }, { "format": "uuid" }, { "format": "time" }, { "format": "date-time" } ] - } - )""", + })""", // Passing strings { // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it? @@ -798,125 +794,113 @@ static void test_json_schema() { test_schema( "string", // Schema - R"""( - { - "type": "string" - } - )""", + R"""({ + "type": "string" + })""", // Passing strings { - "\"foo\"", - "\"bar\"", - "\"\"", + R"""("foo")""", + R"""("bar")""", + R"""("")""", }, // Failing strings { - "{}", - "\"foo\": \"bar\"", + R"""({})""", + R"""("foo": "bar")""", } ); test_schema( "string w/ min length 1", // Schema - R"""( - { - "type": "string", - "minLength": 1 - } - )""", + R"""({ + "type": "string", + "minLength": 1 + })""", // Passing strings { - "\"foo\"", - "\"bar\"", + R"""("foo")""", + R"""("bar")""", }, // Failing strings { - "\"\"", - "{}", - "\"foo\": \"bar\"", + R"""("")""", + R"""({})""", + R"""("foo": "bar")""", } ); test_schema( "string w/ min length 3", // Schema - R"""( - { + R"""({ "type": "string", "minLength": 3 - } - )""", + })""", // Passing strings { - "\"foo\"", - "\"bar\"", - "\"foobar\"", + R"""("foo")""", + R"""("bar")""", + R"""("foobar")""", }, // Failing strings { - "\"\"", - "\"f\"", - "\"fo\"", + R"""("")""", + R"""("f")""", + R"""("fo")""", } ); test_schema( "string w/ max length", // Schema - R"""( - { - "type": "string", - "maxLength": 3 - } - )""", + R"""({ + "type": "string", + "maxLength": 3 + })""", // Passing strings { - "\"foo\"", - "\"bar\"", - "\"\"", - "\"f\"", - "\"fo\"", + R"""("foo")""", + R"""("bar")""", + R"""("")""", + R"""("f")""", + R"""("fo")""", }, // Failing strings { - "\"foobar\"", + R"""("foobar")""", } ); test_schema( "string w/ min & max length", // Schema - R"""( - { - "type": "string", - "minLength": 1, - "maxLength": 4 - } - )""", + R"""({ + "type": "string", + "minLength": 1, + "maxLength": 4 + })""", // Passing strings { - "\"foo\"", - "\"bar\"", - "\"f\"", - "\"barf\"", + R"""("foo")""", + R"""("bar")""", + R"""("f")""", + R"""("barf")""", }, // Failing strings { - "\"\"", - "\"barfo\"", - "\"foobar\"", + R"""("")""", + R"""("barfo")""", + R"""("foobar")""", } ); test_schema( "boolean", // Schema - R"""( - { - "type": "boolean" - } - )""", + R"""({ + "type": "boolean" + })""", // Passing strings { "true", @@ -924,122 +908,112 @@ static void test_json_schema() { }, // Failing strings { - "\"\"", - "\"true\"", - "True", - "FALSE", + R"""("")""", + R"""("true")""", + R"""(True)""", + R"""(FALSE)""", } ); test_schema( "integer", // Schema - R"""( - { - "type": "integer" - } - )""", + R"""({ + "type": "integer" + })""", // Passing strings { - "0", - "12345", - "1234567890123456" + R"""(0)""", + R"""(12345)""", + R"""(1234567890123456)""", }, // Failing strings { - "", - "01", - "007", - "12345678901234567" + R"""()""", + R"""(01)""", + R"""(007)""", + R"""(12345678901234567 )""", } ); test_schema( "string const", // Schema - R"""( - { - "const": "foo" - } - )""", + R"""({ + "const": "foo" + })""", // Passing strings { - "\"foo\"", + R"""("foo")""", }, // Failing strings { - "foo", - "\"bar\"", + R"""(foo)""", + R"""("bar")""", } ); test_schema( "non-string const", // Schema - R"""( - { - "const": true - } - )""", + R"""({ + "const": true + })""", // Passing strings { - "true", + R"""(true)""", }, // Failing strings { - "", - "foo", - "\"true\"", + R"""()""", + R"""(foo)""", + R"""("true")""", } ); test_schema( "non-string const", // Schema - R"""( - { - "enum": ["red", "amber", "green", null, 42, ["foo"]] - } - )""", + R"""({ + "enum": ["red", "amber", "green", null, 42, ["foo"]] + })""", // Passing strings { - "\"red\"", - "null", - "42", - "[\"foo\"]", + R"""("red")""", + R"""(null)""", + R"""(42)""", + R"""(["foo"])""", }, // Failing strings { - "", - "420", - "true", - "foo", + R"""()""", + R"""(420)""", + R"""(true)""", + R"""(foo)""", } ); test_schema( "min+max items", // Schema - R"""( - { - "items": { - "type": ["number", "integer"] - }, - "minItems": 3, - "maxItems": 5 - } - )""", + R"""({ + "items": { + "type": ["number", "integer"] + }, + "minItems": 3, + "maxItems": 5 + })""", // Passing strings { - "[1, 2, 3]", - "[1, 2, 3, 4]", - "[1, 2, 3, 4, 5]", + R"""([1, 2, 3])""", + R"""([1, 2, 3, 4])""", + R"""([1, 2, 3, 4, 5])""", }, // Failing strings { - "[1, 2]", - "[1, 2, 3, 4, 5, 6]", - "1" + R"""([1, 2])""", + R"""([1, 2, 3, 4, 5, 6])""", + R"""(1)""", } ); @@ -1047,16 +1021,14 @@ static void test_json_schema() { test_schema( "object properties", // Schema - R"""( - { + R"""({ "type": "object", "properties": { "number": { "type": "number" }, "street_name": { "type": "string" }, "street_type": { "enum": ["Street", "Avenue", "Boulevard"] } } - } - )""", + })""", // Passing strings { R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""", @@ -1066,12 +1038,8 @@ static void test_json_schema() { // "By extension, even an empty object is valid" R"""({})""", // "By default, providing additional properties is valid" -#ifdef INCLUDE_FAILING_TESTS - // TODO: The following should pass, but currently FAILS. Additional properties should be permitted by default. R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""", - // TODO: Spaces should be permitted around enum values, but currently they fail to pass. R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""", -#endif }, // Failing strings { @@ -1084,13 +1052,35 @@ static void test_json_schema() { } ); + test_schema( + "additional properties can't override other properties", + R"""({ + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"} + }, + "additionalProperties": true + })""", + // Passing strings + { + R"""({"a": 42})""", + R"""({"c": ""})""", + R"""({"a": 42, "c": ""})""", + R"""({"a_": ""})""", + }, + // Failing strings + { + R"""()""", + R"""({"a": ""})""", + R"""({"a": "", "b": ""})""", + } + ); // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties) test_schema( "object properties, additionalProperties: true", // Schema - R"""( - { + R"""({ "type": "object", "properties": { "number": { "type": "number" }, @@ -1098,26 +1088,18 @@ static void test_json_schema() { "street_type": { "enum": ["Street", "Avenue", "Boulevard"] } }, "additionalProperties": true - } - )""", + })""", // Passing strings { // "By extension, even an empty object is valid" R"""({})""", -#ifdef INCLUDE_FAILING_TESTS - // TODO: Following line should pass and doesn't R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""", // "By default, leaving out properties is valid" - // TODO: Following line should pass and doesn't R"""({ "street_name": "Pennsylvania" })""", - // TODO: Following line should pass and doesn't R"""({ "number": 1600, "street_name": "Pennsylvania" })""", // "By default, providing additional properties is valid" - // TODO: The following should pass, but currently FAILS. Additional properties should be permitted by default. R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""", - // TODO: Spaces should be permitted around enum values, but currently they fail to pass. R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""", -#endif }, // Failing strings { @@ -1132,8 +1114,7 @@ static void test_json_schema() { test_schema( "required + optional props each in original order", // Schema - R"""( - { + R"""({ "type": "object", "properties": { "number": { "type": "number" }, @@ -1141,18 +1122,15 @@ static void test_json_schema() { "street_type": { "enum": ["Street", "Avenue", "Boulevard"] } }, "additionalProperties": false - } - )""", + })""", // Passing strings { R"""({ "street_name": "Pennsylvania" })""", R"""({ "number": 1600, "street_type":"Avenue"})""", R"""({ "number": 1600, "street_name": "Pennsylvania" })""", R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""", -#ifdef INCLUDE_FAILING_TESTS - // TODO: Spaces should be permitted around enum values, but currently they fail to pass. + // Spaces are permitted around enum values R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""", -#endif }, // Failing strings { @@ -1166,18 +1144,16 @@ static void test_json_schema() { test_schema( "required + optional props each in original order", // Schema - R"""( - { - "properties": { - "b": {"type": "string"}, - "a": {"type": "string"}, - "d": {"type": "string"}, - "c": {"type": "string"} - }, - "required": ["a", "b"], - "additionalProperties": false - } - )""", + R"""({ + "properties": { + "b": {"type": "string"}, + "a": {"type": "string"}, + "d": {"type": "string"}, + "c": {"type": "string"} + }, + "required": ["a", "b"], + "additionalProperties": false + })""", // Passing strings { R"""({"b": "foo", "a": "bar"})""", @@ -1197,8 +1173,7 @@ static void test_json_schema() { test_schema( "required props", // Schema - R"""( - { + R"""({ "$schema": "https://json-schema.org/draft/2020-12/schema", "$id": "https://example.com/product.schema.json", "title": "Product", @@ -1244,8 +1219,7 @@ static void test_json_schema() { } }, "required": [ "productId", "productName", "price" ] - } - )""", + })""", // Passing strings { R"""({"productId": 1, "productName": "A green door", "price": 12.50})""", diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 2e591bd71..1e69cb6ef 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -473,7 +473,7 @@ static void test_all(const std::string & lang, std::function Date: Wed, 26 Jun 2024 01:46:35 +0100 Subject: [PATCH 3/6] `json`: better support for "type" unions (e.g. nullable arrays w/ typed items) (#7863) * json: better suport for "type" arrays (e.g. `{"type": ["array", "null"], "items": {"type": "string"}}`) * json: add test for type: [array, null] fix * update tests --- common/json-schema-to-grammar.cpp | 4 ++- examples/json_schema_to_grammar.py | 2 +- .../server/public/json-schema-to-grammar.mjs | 2 +- tests/test-grammar-integration.cpp | 25 +++++++++++++++ tests/test-json-schema-to-grammar.cpp | 32 +++++++++++++++++++ 5 files changed, 62 insertions(+), 3 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index b40821dad..2f233e2e7 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -893,7 +893,9 @@ public: } else if (schema_type.is_array()) { std::vector schema_types; for (const auto & t : schema_type) { - schema_types.push_back({{"type", t}}); + json schema_copy(schema); + schema_copy["type"] = t; + schema_types.push_back(schema_copy); } return _add_rule(rule_name, _generate_union_rule(name, schema_types)); } else if (schema.contains("const")) { diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 3f3132f88..92f6e3d47 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -565,7 +565,7 @@ class SchemaConverter: return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) elif isinstance(schema_type, list): - return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) + return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type])) elif 'const' in schema: return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space') diff --git a/examples/server/public/json-schema-to-grammar.mjs b/examples/server/public/json-schema-to-grammar.mjs index 02015bbd4..06d76edde 100644 --- a/examples/server/public/json-schema-to-grammar.mjs +++ b/examples/server/public/json-schema-to-grammar.mjs @@ -616,7 +616,7 @@ export class SchemaConverter { } else if (schema.oneOf || schema.anyOf) { return this._addRule(ruleName, this._generateUnionRule(name, schema.oneOf || schema.anyOf)); } else if (Array.isArray(schemaType)) { - return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t })))); + return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({...schema, type: t})))); } else if ('const' in schema) { return this._addRule(ruleName, this._generateConstantRule(schema.const) + ' space'); } else if ('enum' in schema) { diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 23ef8324c..0e21dc795 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -993,6 +993,31 @@ static void test_json_schema() { } ); + test_schema( + "", + // Schema + R"""( + { + "type": ["array", "null"], + "items": { "type": "string" } + } + )""", + // Passing strings + { + "null", + "[]", + "[\"123\"]", + "[\"foo\", \"bar\"]", + }, + // Failing strings + { + "", + "[123]", + "\"foo\"", + "[\"foo\", 42]", + } + ); + test_schema( "min+max items", // Schema diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 1e69cb6ef..3aaa11833 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -502,6 +502,38 @@ static void test_all(const std::string & lang, std::function Date: Wed, 26 Jun 2024 14:27:46 +0800 Subject: [PATCH 4/6] llama : extend llm_build_ffn() to support _scale tensors (#8103) --- llama.cpp | 255 +++++++++++++++++++++++++++++------------------------- 1 file changed, 135 insertions(+), 120 deletions(-) diff --git a/llama.cpp b/llama.cpp index 989c73149..f78594a6f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7212,10 +7212,13 @@ static struct ggml_tensor * llm_build_ffn( struct ggml_tensor * cur, struct ggml_tensor * up, struct ggml_tensor * up_b, + struct ggml_tensor * up_s, struct ggml_tensor * gate, struct ggml_tensor * gate_b, + struct ggml_tensor * gate_s, struct ggml_tensor * down, struct ggml_tensor * down_b, + struct ggml_tensor * down_s, struct ggml_tensor * act_scales, llm_ffn_op_type type_op, llm_ffn_gate_type type_gate, @@ -7229,6 +7232,11 @@ static struct ggml_tensor * llm_build_ffn( cb(tmp, "ffn_up_b", il); } + if (up_s) { + tmp = ggml_mul(ctx, tmp, up_s); + cb(tmp, "ffn_up_s", il); + } + if (gate) { switch (type_gate) { case LLM_FFN_SEQ: @@ -7247,6 +7255,12 @@ static struct ggml_tensor * llm_build_ffn( cur = ggml_add(ctx, cur, gate_b); cb(cur, "ffn_gate_b", il); } + + if (gate_s) { + cur = ggml_mul(ctx, cur, gate_s); + cb(cur, "ffn_gate_s", il); + } + } else { cur = tmp; } @@ -7286,7 +7300,10 @@ static struct ggml_tensor * llm_build_ffn( cb(cur, "ffn_gate_par", il); } - cur = ggml_mul_mat(ctx, down, cur); + if (down) { + cur = ggml_mul_mat(ctx, down, cur); + } + if (down_b) { cb(cur, "ffn_down", il); } @@ -7295,6 +7312,11 @@ static struct ggml_tensor * llm_build_ffn( cur = ggml_add(ctx, cur, down_b); } + if (down_s) { + cur = ggml_mul(ctx, cur, down_s); + cb(cur, "ffn_down_s", il); + } + return cur; } @@ -8003,9 +8025,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -8137,9 +8159,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -8242,9 +8264,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -8358,9 +8380,9 @@ struct llm_build_context { // feed forward { cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result - model.layers[il].ffn_up, NULL, - NULL, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -8749,9 +8771,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -8841,9 +8863,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -9026,23 +9048,23 @@ struct llm_build_context { // feed-forward network if (model.arch == LLM_ARCH_BERT) { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_PAR, cb, il); } else { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); } @@ -9138,9 +9160,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -9276,9 +9298,9 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, model.layers[il].ffn_act, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -9425,9 +9447,9 @@ struct llm_build_context { cur = inpSA; } cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -9538,9 +9560,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -9651,9 +9673,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -9788,9 +9810,9 @@ struct llm_build_context { cb(cur_gate, "ffn_shexp_gate", il); ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up_shexp, NULL, - model.layers[il].ffn_gate_shexp, NULL, - model.layers[il].ffn_down_shexp, NULL, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur_ffn, "ffn_shexp", il); @@ -9917,9 +9939,9 @@ struct llm_build_context { // FF { ffn_output = llm_build_ffn(ctx0, attn_norm_output, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(ffn_output, "ffn_out", il); @@ -10155,9 +10177,9 @@ struct llm_build_context { // feed-forward network { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -10263,9 +10285,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -10374,9 +10396,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -10491,9 +10513,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -10609,9 +10631,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -10746,9 +10768,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -10863,9 +10885,9 @@ struct llm_build_context { // feed-forward network { cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_GELU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -10983,9 +11005,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -11271,9 +11293,9 @@ struct llm_build_context { // feed-forward network { cur = llm_build_ffn(ctx0, ffn_inp, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -11408,9 +11430,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -11522,9 +11544,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -11553,9 +11575,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -11662,9 +11684,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -11884,9 +11906,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -11912,9 +11934,9 @@ struct llm_build_context { // FFN shared expert { ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up_shexp, NULL, - model.layers[il].ffn_gate_shexp, NULL, - model.layers[il].ffn_down_shexp, NULL, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(ffn_shexp, "ffn_shexp", il); @@ -12017,7 +12039,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, - nullptr, nullptr, + NULL, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cur = llm_build_norm(ctx0, cur, hparams, @@ -12044,35 +12066,28 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); // feed-forward forward - if (model.layers[il].ffn_gate_inp == nullptr) { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); - struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); - tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale); - cb(tmp, "ffn_up", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, + NULL, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_sub_out", il); - cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale); - cb(cur, "ffn_gate", il); + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_sub_norm", il); - cur = ggml_silu(ctx0, cur); - cb(cur, "ffn_silu", il); + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + cb(cur, "ffn_down", il); - cur = ggml_mul(ctx0, cur, tmp); - cb(cur, "ffn_gate_par", il); - - cur = llm_build_norm(ctx0, cur, hparams, - model.layers[il].ffn_sub_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_sub_norm", il); - - cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); - cb(cur, "ffn_down", il); - } cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "l_out", il); From c8771ab5f89387cdd7d9a8a69280dac46b45e02f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 26 Jun 2024 08:28:02 +0200 Subject: [PATCH 5/6] CUDA: fix misaligned shared memory read (#8123) --- ggml-cuda/mma.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda/mma.cuh b/ggml-cuda/mma.cuh index 0301a52f9..5d87dd8e6 100644 --- a/ggml-cuda/mma.cuh +++ b/ggml-cuda/mma.cuh @@ -23,7 +23,7 @@ struct mma_int_A_I16K4 { __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { #if defined(INT8_MMA_AVAILABLE) - const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + const int * xs = xs0 + (threadIdx.x%I)*stride; asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" : "+r"(x[0]), "+r"(x[1]) : "l"(xs)); From 88540445615e77a0177fcca43aaa8e9d8eea6864 Mon Sep 17 00:00:00 2001 From: Isaac McFadyen Date: Wed, 26 Jun 2024 02:29:28 -0400 Subject: [PATCH 6/6] Clarify default MMQ for CUDA and LLAMA_CUDA_FORCE_MMQ flag (#8115) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add message about int8 support * Add suggestions from review Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a54ee3951..95d970d83 100644 --- a/README.md +++ b/README.md @@ -511,7 +511,7 @@ Building the program with BLAS support may lead to some performance improvements | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | - | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). Speed for large batch sizes will be worse but VRAM consumption will be lower. | + | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | | LLAMA_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |