mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 21:37:19 +01:00
json-schema-to-grammar : fix order of props + non-str const/enum (#6232)
* json: ordered json in server/schema converter to respect orig order * json: ws nits * json: support non-string const / enums
This commit is contained in:
parent
2f0e81e053
commit
72114edf06
@ -9,7 +9,7 @@
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
const std::string SPACE_RULE = "\" \"?";
|
const std::string SPACE_RULE = "\" \"?";
|
||||||
|
|
||||||
@ -124,7 +124,7 @@ static std::string replacePattern(const std::string & input, const std::regex &
|
|||||||
}
|
}
|
||||||
|
|
||||||
static std::string format_literal(const std::string & literal) {
|
static std::string format_literal(const std::string & literal) {
|
||||||
std::string escaped = replacePattern(json(literal).dump(), GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
|
std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
|
||||||
char c = match.str()[0];
|
char c = match.str()[0];
|
||||||
return GRAMMAR_LITERAL_ESCAPES.at(c);
|
return GRAMMAR_LITERAL_ESCAPES.at(c);
|
||||||
});
|
});
|
||||||
@ -137,7 +137,7 @@ private:
|
|||||||
std::function<json(const std::string &)> _fetch_json;
|
std::function<json(const std::string &)> _fetch_json;
|
||||||
bool _dotall;
|
bool _dotall;
|
||||||
std::map<std::string, std::string> _rules;
|
std::map<std::string, std::string> _rules;
|
||||||
std::unordered_map<std::string, nlohmann::json> _refs;
|
std::unordered_map<std::string, json> _refs;
|
||||||
std::unordered_set<std::string> _refs_being_resolved;
|
std::unordered_set<std::string> _refs_being_resolved;
|
||||||
std::vector<std::string> _errors;
|
std::vector<std::string> _errors;
|
||||||
std::vector<std::string> _warnings;
|
std::vector<std::string> _warnings;
|
||||||
@ -413,7 +413,7 @@ private:
|
|||||||
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
|
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
|
||||||
prop_kv_rule_names[prop_name] = _add_rule(
|
prop_kv_rule_names[prop_name] = _add_rule(
|
||||||
name + (name.empty() ? "" : "-") + prop_name + "-kv",
|
name + (name.empty() ? "" : "-") + prop_name + "-kv",
|
||||||
format_literal(prop_name) + " space \":\" space " + prop_rule_name
|
format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
|
||||||
);
|
);
|
||||||
if (required.find(prop_name) != required.end()) {
|
if (required.find(prop_name) != required.end()) {
|
||||||
required_props.push_back(prop_name);
|
required_props.push_back(prop_name);
|
||||||
@ -495,7 +495,7 @@ public:
|
|||||||
_rules["space"] = SPACE_RULE;
|
_rules["space"] = SPACE_RULE;
|
||||||
}
|
}
|
||||||
|
|
||||||
void resolve_refs(nlohmann::json & schema, const std::string & url) {
|
void resolve_refs(json & schema, const std::string & url) {
|
||||||
/*
|
/*
|
||||||
* Resolves all $ref fields in the given schema, fetching any remote schemas,
|
* Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||||
* replacing each $ref with absolute reference URL and populates _refs with the
|
* replacing each $ref with absolute reference URL and populates _refs with the
|
||||||
@ -557,11 +557,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string _generate_constant_rule(const json & value) {
|
std::string _generate_constant_rule(const json & value) {
|
||||||
if (!value.is_string()) {
|
return format_literal(value.dump());
|
||||||
_errors.push_back("Only std::string constants are supported, got " + value.dump());
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
return format_literal(value.get<std::string>());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string visit(const json & schema, const std::string & name) {
|
std::string visit(const json & schema, const std::string & name) {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
|
|
||||||
std::string json_schema_to_grammar(const nlohmann::json& schema);
|
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
|
||||||
|
@ -61,7 +61,7 @@ class SchemaConverter:
|
|||||||
|
|
||||||
def _format_literal(self, literal):
|
def _format_literal(self, literal):
|
||||||
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
|
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
|
||||||
)
|
)
|
||||||
return f'"{escaped}"'
|
return f'"{escaped}"'
|
||||||
|
|
||||||
@ -308,8 +308,7 @@ class SchemaConverter:
|
|||||||
return ref_name
|
return ref_name
|
||||||
|
|
||||||
def _generate_constant_rule(self, value):
|
def _generate_constant_rule(self, value):
|
||||||
assert isinstance(value, str), f'Only string constants are supported, got {value}'
|
return self._format_literal(json.dumps(value))
|
||||||
return self._format_literal(value)
|
|
||||||
|
|
||||||
def visit(self, schema, name):
|
def visit(self, schema, name):
|
||||||
schema_type = schema.get('type')
|
schema_type = schema.get('type')
|
||||||
@ -428,7 +427,7 @@ class SchemaConverter:
|
|||||||
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
|
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
|
||||||
prop_kv_rule_names[prop_name] = self._add_rule(
|
prop_kv_rule_names[prop_name] = self._add_rule(
|
||||||
f'{name}{"-" if name else ""}{prop_name}-kv',
|
f'{name}{"-" if name else ""}{prop_name}-kv',
|
||||||
fr'{self._format_literal(prop_name)} space ":" space {prop_rule_name}'
|
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
|
||||||
)
|
)
|
||||||
required_props = [k for k in sorted_props if k in required]
|
required_props = [k for k in sorted_props if k in required]
|
||||||
optional_props = [k for k in sorted_props if k not in required]
|
optional_props = [k for k in sorted_props if k not in required]
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -48,7 +48,7 @@ export class SchemaConverter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_formatLiteral(literal) {
|
_formatLiteral(literal) {
|
||||||
const escaped = JSON.stringify(literal).replace(
|
const escaped = literal.replace(
|
||||||
GRAMMAR_LITERAL_ESCAPE_RE,
|
GRAMMAR_LITERAL_ESCAPE_RE,
|
||||||
m => GRAMMAR_LITERAL_ESCAPES[m]
|
m => GRAMMAR_LITERAL_ESCAPES[m]
|
||||||
);
|
);
|
||||||
@ -327,10 +327,7 @@ export class SchemaConverter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_generateConstantRule(value) {
|
_generateConstantRule(value) {
|
||||||
if (typeof value !== 'string') {
|
return this._formatLiteral(JSON.stringify(value));
|
||||||
throw new Error('Only string constants are supported, got ' + JSON.stringify(value));
|
|
||||||
}
|
|
||||||
return this._formatLiteral(value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
visit(schema, name) {
|
visit(schema, name) {
|
||||||
@ -346,9 +343,6 @@ export class SchemaConverter {
|
|||||||
} else if (Array.isArray(schemaType)) {
|
} else if (Array.isArray(schemaType)) {
|
||||||
return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t }))));
|
return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t }))));
|
||||||
} else if ('const' in schema) {
|
} else if ('const' in schema) {
|
||||||
if (typeof schema.const !== 'string') {
|
|
||||||
throw new Error('Only string constants are supported, got ' + JSON.stringify(schema.const));
|
|
||||||
}
|
|
||||||
return this._addRule(ruleName, this._generateConstantRule(schema.const));
|
return this._addRule(ruleName, this._generateConstantRule(schema.const));
|
||||||
} else if ('enum' in schema) {
|
} else if ('enum' in schema) {
|
||||||
const rule = schema.enum.map(v => this._generateConstantRule(v)).join(' | ');
|
const rule = schema.enum.map(v => this._generateConstantRule(v)).join(' | ');
|
||||||
@ -457,7 +451,7 @@ export class SchemaConverter {
|
|||||||
const propRuleName = this.visit(propSchema, `${name ?? ''}${name ? '-' : ''}${propName}`);
|
const propRuleName = this.visit(propSchema, `${name ?? ''}${name ? '-' : ''}${propName}`);
|
||||||
propKvRuleNames[propName] = this._addRule(
|
propKvRuleNames[propName] = this._addRule(
|
||||||
`${name ?? ''}${name ? '-' : ''}${propName}-kv`,
|
`${name ?? ''}${name ? '-' : ''}${propName}-kv`,
|
||||||
`${this._formatLiteral(propName)} space ":" space ${propRuleName}`
|
`${this._formatLiteral(JSON.stringify(propName))} space ":" space ${propRuleName}`
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
const requiredProps = sortedProps.filter(k => required.has(k));
|
const requiredProps = sortedProps.filter(k => required.has(k));
|
||||||
|
@ -30,7 +30,7 @@
|
|||||||
#include <signal.h>
|
#include <signal.h>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
bool server_verbose = false;
|
bool server_verbose = false;
|
||||||
bool server_log_json = true;
|
bool server_log_json = true;
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
|
|
||||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||||
enum error_type {
|
enum error_type {
|
||||||
|
@ -90,7 +90,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||||||
|
|
||||||
test({
|
test({
|
||||||
FAILURE,
|
FAILURE,
|
||||||
"invalid type type",
|
"invalid type",
|
||||||
R"""({
|
R"""({
|
||||||
"type": 123
|
"type": 123
|
||||||
})""",
|
})""",
|
||||||
@ -193,21 +193,27 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||||||
});
|
});
|
||||||
|
|
||||||
test({
|
test({
|
||||||
FAILURE,
|
SUCCESS,
|
||||||
"non-string const",
|
"non-string const",
|
||||||
R"""({
|
R"""({
|
||||||
"const": 123
|
"const": 123
|
||||||
})""",
|
})""",
|
||||||
""
|
R"""(
|
||||||
|
root ::= "123"
|
||||||
|
space ::= " "?
|
||||||
|
)"""
|
||||||
});
|
});
|
||||||
|
|
||||||
test({
|
test({
|
||||||
FAILURE,
|
SUCCESS,
|
||||||
"non-string enum",
|
"non-string enum",
|
||||||
R"""({
|
R"""({
|
||||||
"enum": [123]
|
"enum": ["red", "amber", "green", null, 42, ["foo"]]
|
||||||
})""",
|
})""",
|
||||||
""
|
R"""(
|
||||||
|
root ::= "\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]"
|
||||||
|
space ::= " "?
|
||||||
|
)"""
|
||||||
});
|
});
|
||||||
|
|
||||||
test({
|
test({
|
||||||
@ -378,20 +384,18 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||||||
|
|
||||||
test({
|
test({
|
||||||
SUCCESS,
|
SUCCESS,
|
||||||
"required props",
|
"required props in original order",
|
||||||
R"""({
|
R"""({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {
|
"b": {"type": "string"},
|
||||||
"type": "string"
|
"c": {"type": "string"},
|
||||||
},
|
"a": {"type": "string"}
|
||||||
"b": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
"a",
|
"a",
|
||||||
"b"
|
"b",
|
||||||
|
"c"
|
||||||
],
|
],
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"definitions": {}
|
"definitions": {}
|
||||||
@ -399,7 +403,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||||||
R"""(
|
R"""(
|
||||||
a-kv ::= "\"a\"" space ":" space string
|
a-kv ::= "\"a\"" space ":" space string
|
||||||
b-kv ::= "\"b\"" space ":" space string
|
b-kv ::= "\"b\"" space ":" space string
|
||||||
root ::= "{" space a-kv "," space b-kv "}" space
|
c-kv ::= "\"c\"" space ":" space string
|
||||||
|
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
|
||||||
space ::= " "?
|
space ::= " "?
|
||||||
string ::= "\"" (
|
string ::= "\"" (
|
||||||
[^"\\] |
|
[^"\\] |
|
||||||
@ -458,13 +463,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||||||
|
|
||||||
test({
|
test({
|
||||||
SUCCESS,
|
SUCCESS,
|
||||||
"required + optional props",
|
"required + optional props each in original order",
|
||||||
R"""({
|
R"""({
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {"type": "string"},
|
|
||||||
"b": {"type": "string"},
|
"b": {"type": "string"},
|
||||||
"c": {"type": "string"},
|
"a": {"type": "string"},
|
||||||
"d": {"type": "string"}
|
"d": {"type": "string"},
|
||||||
|
"c": {"type": "string"}
|
||||||
},
|
},
|
||||||
"required": ["a", "b"],
|
"required": ["a", "b"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
@ -473,9 +478,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||||||
a-kv ::= "\"a\"" space ":" space string
|
a-kv ::= "\"a\"" space ":" space string
|
||||||
b-kv ::= "\"b\"" space ":" space string
|
b-kv ::= "\"b\"" space ":" space string
|
||||||
c-kv ::= "\"c\"" space ":" space string
|
c-kv ::= "\"c\"" space ":" space string
|
||||||
c-rest ::= ( "," space d-kv )?
|
|
||||||
d-kv ::= "\"d\"" space ":" space string
|
d-kv ::= "\"d\"" space ":" space string
|
||||||
root ::= "{" space a-kv "," space b-kv ( "," space ( c-kv c-rest | d-kv ) )? "}" space
|
d-rest ::= ( "," space c-kv )?
|
||||||
|
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||||
space ::= " "?
|
space ::= " "?
|
||||||
string ::= "\"" (
|
string ::= "\"" (
|
||||||
[^"\\] |
|
[^"\\] |
|
||||||
@ -796,7 +801,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||||||
int main() {
|
int main() {
|
||||||
test_all("C++", [](const TestCase & tc) {
|
test_all("C++", [](const TestCase & tc) {
|
||||||
try {
|
try {
|
||||||
tc.verify(json_schema_to_grammar(nlohmann::json::parse(tc.schema)));
|
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
|
||||||
tc.verify_status(SUCCESS);
|
tc.verify_status(SUCCESS);
|
||||||
} catch (const std::runtime_error & ex) {
|
} catch (const std::runtime_error & ex) {
|
||||||
fprintf(stderr, "Error: %s\n", ex.what());
|
fprintf(stderr, "Error: %s\n", ex.what());
|
||||||
|
Loading…
x
Reference in New Issue
Block a user