mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 13:58:56 +01:00
fix: grammar not support utf-8 (#5900)
This commit is contained in:
parent
8456d13349
commit
5cb59707f3
@ -60,7 +60,7 @@ def hex_to_int(c):
|
||||
return int(c)
|
||||
elif "a" <= c.lower() <= "f":
|
||||
return ord(c.lower()) - ord("a") + 10
|
||||
return -1
|
||||
raise RuntimeError("unknown hex char " + c)
|
||||
|
||||
|
||||
def remove_leading_white_space(src, newline_ok):
|
||||
@ -100,6 +100,13 @@ def parse_name(src):
|
||||
return src[:pos], src[pos:]
|
||||
|
||||
|
||||
def read_hex(s):
|
||||
val = 0
|
||||
for c in s:
|
||||
val = (val << 4) + hex_to_int(c)
|
||||
return chr(val)
|
||||
|
||||
|
||||
def parse_char(src):
|
||||
"""
|
||||
parse the leading char from the input string
|
||||
@ -111,13 +118,12 @@ def parse_char(src):
|
||||
if src[0] == "\\":
|
||||
esc = src[1]
|
||||
if esc == "x":
|
||||
first = hex_to_int(src[2])
|
||||
if first > -1:
|
||||
second = hex_to_int(src[3])
|
||||
if second > -1:
|
||||
return (first << 4) + second, src[4:]
|
||||
raise RuntimeError("expecting \\xNN at " + src)
|
||||
elif esc in ('"', "[", "]"):
|
||||
return read_hex(src[2:4]), src[4:]
|
||||
elif esc == "u":
|
||||
return read_hex(src[2:6]), src[6:]
|
||||
elif esc == "U":
|
||||
return read_hex(src[2:10]), src[10:]
|
||||
elif esc in ('"', "[", "]", "\\", "-"):
|
||||
return esc, src[2:]
|
||||
elif esc == "r":
|
||||
return "\r", src[2:]
|
||||
@ -454,7 +460,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
|
||||
def __init__(self, grammar_str, start_rule_name, tokenizer):
|
||||
super().__init__(grammar_str, start_rule_name, tokenizer)
|
||||
|
||||
def accept_char(self, byte, stacks):
|
||||
def accept_char(self, char, stacks):
|
||||
byte = ord(char)
|
||||
new_stacks = []
|
||||
for stack in stacks:
|
||||
# stack is empty
|
||||
@ -471,6 +478,9 @@ class IncrementalGrammarConstraint(GrammarConstraint):
|
||||
if self.grammar_encoding[pos + i] <= byte and byte <= self.grammar_encoding[pos + i + 1]:
|
||||
found = True
|
||||
break
|
||||
if self.grammar_encoding[pos + i] >= byte and byte >= self.grammar_encoding[pos + i + 1]:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
continue
|
||||
|
||||
@ -483,9 +493,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
|
||||
return new_stacks
|
||||
|
||||
def accept_string(self, string: str, stacks: List[List[int]]):
|
||||
_bytes = bytes(string, "utf-8")
|
||||
for byte in _bytes:
|
||||
stacks = self.accept_char(byte, stacks)
|
||||
for char in string:
|
||||
stacks = self.accept_char(char, stacks)
|
||||
return stacks
|
||||
|
||||
def accept_token_id(self, token_id: int, stacks: List[List[int]]):
|
||||
@ -537,16 +546,18 @@ class IncrementalGrammarConstraint(GrammarConstraint):
|
||||
|
||||
# For each sub-rule in the grammar, cache whether each byte is accepted.
|
||||
@lru_cache(maxsize=None)
|
||||
def pos_char_acceptance(self, pos):
|
||||
acceptance = [False] * 256
|
||||
def pos_char_acceptance(self, pos, char):
|
||||
byte = ord(char)
|
||||
num_chars = self.grammar_encoding[pos]
|
||||
pos += 1
|
||||
for i in range(0, num_chars, 2):
|
||||
start = self.grammar_encoding[pos + i]
|
||||
end = self.grammar_encoding[pos + i + 1]
|
||||
for j in range(start, end + 1):
|
||||
acceptance[j] = True
|
||||
return acceptance
|
||||
if byte >= start and byte <= end:
|
||||
return True
|
||||
if byte <= start and byte >= end:
|
||||
return True
|
||||
return False
|
||||
|
||||
# Probably this should be configurable. If the grammar has an exceedingly
|
||||
# large number of states, the correct setting is a tradeoff between GPU
|
||||
@ -580,7 +591,7 @@ class IncrementalGrammarConstraint(GrammarConstraint):
|
||||
pos = stk[-1]
|
||||
num_chars = self.grammar_encoding[pos]
|
||||
|
||||
if not self.pos_char_acceptance(pos)[byte]:
|
||||
if not self.pos_char_acceptance(pos, byte):
|
||||
continue
|
||||
|
||||
pos += num_chars + 1
|
||||
@ -657,14 +668,14 @@ class TokenTrie:
|
||||
token = tokenizer.convert_ids_to_tokens(id)
|
||||
token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
|
||||
token = token.replace("▁", " ")
|
||||
return bytes(token, "utf-8")
|
||||
return token
|
||||
|
||||
else:
|
||||
print("Warning: unrecognized tokenizer: using default token formatting")
|
||||
|
||||
def fmt_token(id):
|
||||
token = tokenizer.convert_ids_to_tokens(id)
|
||||
return bytes(token, "utf-8")
|
||||
return token
|
||||
|
||||
# note: vocab_size doesn't work here because there are also
|
||||
# get_added_vocab() tokens
|
||||
|
Loading…
Reference in New Issue
Block a user