fix: grammar not support utf-8 (#5900)

This commit is contained in:
A0nameless0man 2024-05-20 07:10:39 +08:00 committed by GitHub
parent 8456d13349
commit 5cb59707f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -60,7 +60,7 @@ def hex_to_int(c):
return int(c) return int(c)
elif "a" <= c.lower() <= "f": elif "a" <= c.lower() <= "f":
return ord(c.lower()) - ord("a") + 10 return ord(c.lower()) - ord("a") + 10
return -1 raise RuntimeError("unknown hex char " + c)
def remove_leading_white_space(src, newline_ok): def remove_leading_white_space(src, newline_ok):
@ -100,6 +100,13 @@ def parse_name(src):
return src[:pos], src[pos:] return src[:pos], src[pos:]
def read_hex(s):
val = 0
for c in s:
val = (val << 4) + hex_to_int(c)
return chr(val)
def parse_char(src): def parse_char(src):
""" """
parse the leading char from the input string parse the leading char from the input string
@ -111,13 +118,12 @@ def parse_char(src):
if src[0] == "\\": if src[0] == "\\":
esc = src[1] esc = src[1]
if esc == "x": if esc == "x":
first = hex_to_int(src[2]) return read_hex(src[2:4]), src[4:]
if first > -1: elif esc == "u":
second = hex_to_int(src[3]) return read_hex(src[2:6]), src[6:]
if second > -1: elif esc == "U":
return (first << 4) + second, src[4:] return read_hex(src[2:10]), src[10:]
raise RuntimeError("expecting \\xNN at " + src) elif esc in ('"', "[", "]", "\\", "-"):
elif esc in ('"', "[", "]"):
return esc, src[2:] return esc, src[2:]
elif esc == "r": elif esc == "r":
return "\r", src[2:] return "\r", src[2:]
@ -454,7 +460,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
def __init__(self, grammar_str, start_rule_name, tokenizer): def __init__(self, grammar_str, start_rule_name, tokenizer):
super().__init__(grammar_str, start_rule_name, tokenizer) super().__init__(grammar_str, start_rule_name, tokenizer)
def accept_char(self, byte, stacks): def accept_char(self, char, stacks):
byte = ord(char)
new_stacks = [] new_stacks = []
for stack in stacks: for stack in stacks:
# stack is empty # stack is empty
@ -471,6 +478,9 @@ class IncrementalGrammarConstraint(GrammarConstraint):
if self.grammar_encoding[pos + i] <= byte and byte <= self.grammar_encoding[pos + i + 1]: if self.grammar_encoding[pos + i] <= byte and byte <= self.grammar_encoding[pos + i + 1]:
found = True found = True
break break
if self.grammar_encoding[pos + i] >= byte and byte >= self.grammar_encoding[pos + i + 1]:
found = True
break
if not found: if not found:
continue continue
@ -483,9 +493,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
return new_stacks return new_stacks
def accept_string(self, string: str, stacks: List[List[int]]): def accept_string(self, string: str, stacks: List[List[int]]):
_bytes = bytes(string, "utf-8") for char in string:
for byte in _bytes: stacks = self.accept_char(char, stacks)
stacks = self.accept_char(byte, stacks)
return stacks return stacks
def accept_token_id(self, token_id: int, stacks: List[List[int]]): def accept_token_id(self, token_id: int, stacks: List[List[int]]):
@ -537,16 +546,18 @@ class IncrementalGrammarConstraint(GrammarConstraint):
# For each sub-rule in the grammar, cache whether each byte is accepted. # For each sub-rule in the grammar, cache whether each byte is accepted.
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def pos_char_acceptance(self, pos): def pos_char_acceptance(self, pos, char):
acceptance = [False] * 256 byte = ord(char)
num_chars = self.grammar_encoding[pos] num_chars = self.grammar_encoding[pos]
pos += 1 pos += 1
for i in range(0, num_chars, 2): for i in range(0, num_chars, 2):
start = self.grammar_encoding[pos + i] start = self.grammar_encoding[pos + i]
end = self.grammar_encoding[pos + i + 1] end = self.grammar_encoding[pos + i + 1]
for j in range(start, end + 1): if byte >= start and byte <= end:
acceptance[j] = True return True
return acceptance if byte <= start and byte >= end:
return True
return False
# Probably this should be configurable. If the grammar has an exceedingly # Probably this should be configurable. If the grammar has an exceedingly
# large number of states, the correct setting is a tradeoff between GPU # large number of states, the correct setting is a tradeoff between GPU
@ -580,7 +591,7 @@ class IncrementalGrammarConstraint(GrammarConstraint):
pos = stk[-1] pos = stk[-1]
num_chars = self.grammar_encoding[pos] num_chars = self.grammar_encoding[pos]
if not self.pos_char_acceptance(pos)[byte]: if not self.pos_char_acceptance(pos, byte):
continue continue
pos += num_chars + 1 pos += num_chars + 1
@ -657,14 +668,14 @@ class TokenTrie:
token = tokenizer.convert_ids_to_tokens(id) token = tokenizer.convert_ids_to_tokens(id)
token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token) token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
token = token.replace("", " ") token = token.replace("", " ")
return bytes(token, "utf-8") return token
else: else:
print("Warning: unrecognized tokenizer: using default token formatting") print("Warning: unrecognized tokenizer: using default token formatting")
def fmt_token(id): def fmt_token(id):
token = tokenizer.convert_ids_to_tokens(id) token = tokenizer.convert_ids_to_tokens(id)
return bytes(token, "utf-8") return token
# note: vocab_size doesn't work here because there are also # note: vocab_size doesn't work here because there are also
# get_added_vocab() tokens # get_added_vocab() tokens