mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 23:57:58 +01:00
Better HF grammar implementation (#4953)
This commit is contained in:
parent
aa200f8723
commit
12690d3ffc
@ -1,7 +0,0 @@
|
|||||||
# A probably incorrect grammar for Japanese
|
|
||||||
root ::= jp-char+ ([ \t\n] jp-char+)*
|
|
||||||
jp-char ::= hiragana | katakana | punctuation | cjk
|
|
||||||
hiragana ::= [ぁ-ゟ]
|
|
||||||
katakana ::= [ァ-ヿ]
|
|
||||||
punctuation ::= [、-〾]
|
|
||||||
cjk ::= [一-鿿]
|
|
@ -1,25 +1,14 @@
|
|||||||
root ::= object
|
root ::= object
|
||||||
|
|
||||||
|
object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}"
|
||||||
|
|
||||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
object ::=
|
array ::= "[" ws ( value ("," ws value)* )? "]" ws
|
||||||
"{" ws (
|
|
||||||
string ":" ws value
|
|
||||||
("," ws string ":" ws value)*
|
|
||||||
)? "}" ws
|
|
||||||
|
|
||||||
array ::=
|
string ::= "\"" ( [a-zA-Z0-9] )* "\"" ws
|
||||||
"[" ws (
|
|
||||||
value
|
|
||||||
("," ws value)*
|
|
||||||
)? "]" ws
|
|
||||||
|
|
||||||
string ::=
|
|
||||||
"\"" (
|
|
||||||
[^"\\] |
|
|
||||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
|
||||||
)* "\"" ws
|
|
||||||
|
|
||||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
|
||||||
ws ::= ([ \t\n] ws)?
|
ws ::= ([ \t\n] ws)?
|
||||||
|
@ -1,34 +0,0 @@
|
|||||||
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
|
|
||||||
# Useful for generating JSON arrays
|
|
||||||
|
|
||||||
root ::= arr
|
|
||||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
|
||||||
|
|
||||||
arr ::=
|
|
||||||
"[\n" ws (
|
|
||||||
value
|
|
||||||
(",\n" ws value)*
|
|
||||||
)? "]"
|
|
||||||
|
|
||||||
object ::=
|
|
||||||
"{" ws (
|
|
||||||
string ":" ws value
|
|
||||||
("," ws string ":" ws value)*
|
|
||||||
)? "}" ws
|
|
||||||
|
|
||||||
array ::=
|
|
||||||
"[" ws (
|
|
||||||
value
|
|
||||||
("," ws value)*
|
|
||||||
)? "]" ws
|
|
||||||
|
|
||||||
string ::=
|
|
||||||
"\"" (
|
|
||||||
[^"\\] |
|
|
||||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
|
||||||
)* "\"" ws
|
|
||||||
|
|
||||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
|
||||||
|
|
||||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
|
||||||
ws ::= ([ \t\n] ws)?
|
|
14
grammars/json_w_trailing_space.gbnf
Normal file
14
grammars/json_w_trailing_space.gbnf
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
root ::= object
|
||||||
|
|
||||||
|
object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws
|
||||||
|
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
|
array ::= "[" ws ( value ("," ws value)* )? "]" ws
|
||||||
|
|
||||||
|
string ::= "\"" ( [a-zA-Z0-9] )* "\"" ws
|
||||||
|
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
|
|
||||||
|
ws ::= ([ \t\n] ws)?
|
@ -1,4 +1,2 @@
|
|||||||
root ::= item+
|
root ::= "1. " paragraph "\n" ([0-9] [0-9]? ". " paragraph "\n")+
|
||||||
|
paragraph ::= [a-zA-Z'.,; ]+
|
||||||
# Excludes various line break characters
|
|
||||||
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
|
|
7
grammars/simple_arithmetic.gbnf
Normal file
7
grammars/simple_arithmetic.gbnf
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
root ::= (expr "=" ws term "\n")+
|
||||||
|
expr ::= term ([-+*/] term)*
|
||||||
|
term ::= num | "(" ws expr ")" ws
|
||||||
|
num ::= [0-9]+ ws
|
||||||
|
ws ::= [ \t\n]*
|
||||||
|
# this is a comment
|
||||||
|
|
@ -1,33 +0,0 @@
|
|||||||
from torch_grammar import GrammarSampler
|
|
||||||
from transformers.generation.logits_process import LogitsProcessor
|
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
sampler = None
|
|
||||||
grammar = None
|
|
||||||
grammar_string = ''
|
|
||||||
|
|
||||||
|
|
||||||
class GrammarLogitsProcessor(LogitsProcessor):
|
|
||||||
def __init__(self, string):
|
|
||||||
|
|
||||||
global sampler, grammar, grammar_string
|
|
||||||
|
|
||||||
if string != grammar_string:
|
|
||||||
grammar_string = string
|
|
||||||
if string.strip() != '':
|
|
||||||
string = string.strip() + '\n'
|
|
||||||
sampler = GrammarSampler(string, 'root', shared.tokenizer)
|
|
||||||
else:
|
|
||||||
sampler = None
|
|
||||||
|
|
||||||
if sampler is not None:
|
|
||||||
grammar = sampler.logits_processor()
|
|
||||||
else:
|
|
||||||
grammar = None
|
|
||||||
|
|
||||||
def __call__(self, input_ids, scores):
|
|
||||||
if grammar is not None:
|
|
||||||
scores = grammar(input_ids, scores)
|
|
||||||
|
|
||||||
return scores
|
|
687
modules/grammar/grammar_utils.py
Normal file
687
modules/grammar/grammar_utils.py
Normal file
@ -0,0 +1,687 @@
|
|||||||
|
'''
|
||||||
|
This file has been 100% copied from this PR to the Transformers library:
|
||||||
|
https://github.com/huggingface/transformers/pull/27557
|
||||||
|
|
||||||
|
Author: Saibo-creator
|
||||||
|
Author GitHub: https://github.com/Saibo-creator
|
||||||
|
|
||||||
|
All credits go to the author.
|
||||||
|
'''
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from abc import ABC
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
########################
|
||||||
|
# EBNF Grammar Parsing #
|
||||||
|
########################
|
||||||
|
|
||||||
|
END_OF_ALTERNATE_MARKER = 0
|
||||||
|
END_OF_RULE_MARKER = 0
|
||||||
|
TO_BE_FILLED_MARKER = 0
|
||||||
|
REF_RULE_MARKER = 1
|
||||||
|
LITERAL_MARKER = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ParseState:
|
||||||
|
def __init__(self):
|
||||||
|
self.symbol_ids = {}
|
||||||
|
self.grammar_encoding = [] # old name: out_grammar
|
||||||
|
|
||||||
|
|
||||||
|
def get_symbol_id(state, src):
|
||||||
|
if src not in state.symbol_ids:
|
||||||
|
state.symbol_ids[src] = len(state.symbol_ids)
|
||||||
|
return state.symbol_ids[src]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_symbol_id(state, base_name):
|
||||||
|
next_id = len(state.symbol_ids)
|
||||||
|
state.symbol_ids[base_name + "_" + str(next_id)] = next_id
|
||||||
|
return next_id
|
||||||
|
|
||||||
|
|
||||||
|
def is_word_char(c):
|
||||||
|
return c.isalnum() or c == "-" or c == "_"
|
||||||
|
|
||||||
|
|
||||||
|
def hex_to_int(c):
|
||||||
|
if c.isdigit():
|
||||||
|
return int(c)
|
||||||
|
elif "a" <= c.lower() <= "f":
|
||||||
|
return ord(c.lower()) - ord("a") + 10
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
|
def remove_leading_white_space(src, newline_ok):
|
||||||
|
"""
|
||||||
|
Skips over whitespace and comments in the input string.
|
||||||
|
This function processes the input string, skipping over any spaces, tabs,
|
||||||
|
and content following a '#' character, which denotes a comment. The parsing
|
||||||
|
of a comment continues until the end of the line (denoted by newline characters
|
||||||
|
'\r' or '\n'). If the 'newline_ok' parameter is set to False, the function
|
||||||
|
will stop processing and return the remaining string upon encountering a
|
||||||
|
newline character, otherwise it will skip over newline characters as well.
|
||||||
|
Parameters:
|
||||||
|
src (str): The input string to be processed.
|
||||||
|
newline_ok (bool): A flag indicating whether encountering a newline character
|
||||||
|
should stop the parsing (False) or if it should be skipped (True).
|
||||||
|
Returns:
|
||||||
|
str: The remaining portion of the input string after skipping whitespace and comments.
|
||||||
|
"""
|
||||||
|
pos = 0
|
||||||
|
while pos < len(src) and (src[pos].isspace() or src[pos] == "#"):
|
||||||
|
if src[pos] == "#":
|
||||||
|
while pos < len(src) and src[pos] not in ("\r", "\n"):
|
||||||
|
pos += 1
|
||||||
|
else:
|
||||||
|
if not newline_ok and src[pos] in ("\r", "\n"):
|
||||||
|
break
|
||||||
|
pos += 1
|
||||||
|
return src[pos:]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_name(src):
|
||||||
|
pos = 0
|
||||||
|
while pos < len(src) and is_word_char(src[pos]):
|
||||||
|
pos += 1
|
||||||
|
if pos == 0:
|
||||||
|
raise RuntimeError("expecting name at " + src)
|
||||||
|
return src[:pos], src[pos:]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_char(src):
|
||||||
|
"""
|
||||||
|
parse the leading char from the input string
|
||||||
|
:param src:
|
||||||
|
:return: char, remaining_src
|
||||||
|
"""
|
||||||
|
|
||||||
|
# if we have a backslash, it's maybe an escape
|
||||||
|
if src[0] == "\\":
|
||||||
|
esc = src[1]
|
||||||
|
if esc == "x":
|
||||||
|
first = hex_to_int(src[2])
|
||||||
|
if first > -1:
|
||||||
|
second = hex_to_int(src[3])
|
||||||
|
if second > -1:
|
||||||
|
return (first << 4) + second, src[4:]
|
||||||
|
raise RuntimeError("expecting \\xNN at " + src)
|
||||||
|
elif esc in ('"', "[", "]"):
|
||||||
|
return esc, src[2:]
|
||||||
|
elif esc == "r":
|
||||||
|
return "\r", src[2:]
|
||||||
|
elif esc == "n":
|
||||||
|
return "\n", src[2:]
|
||||||
|
elif esc == "t":
|
||||||
|
return "\t", src[2:]
|
||||||
|
raise RuntimeError("unknown escape at " + src)
|
||||||
|
elif src:
|
||||||
|
return src[0], src[1:]
|
||||||
|
raise RuntimeError("unexpected end of input")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_sequence(state, src, rule_name, outbuf, is_nested):
|
||||||
|
out_start_pos = len(outbuf)
|
||||||
|
|
||||||
|
# sequence size, will be replaced at end when known
|
||||||
|
outbuf.append(TO_BE_FILLED_MARKER)
|
||||||
|
|
||||||
|
last_sym_start = len(outbuf)
|
||||||
|
remaining_src = src
|
||||||
|
while remaining_src:
|
||||||
|
if remaining_src[0] == '"': # literal string
|
||||||
|
remaining_src = remaining_src[1:]
|
||||||
|
last_sym_start = len(outbuf)
|
||||||
|
while remaining_src[0] != '"':
|
||||||
|
char, remaining_src = parse_char(remaining_src)
|
||||||
|
|
||||||
|
# each char of a literal is encoded as a "range" of char - char
|
||||||
|
outbuf.append(LITERAL_MARKER)
|
||||||
|
outbuf.append(ord(char))
|
||||||
|
outbuf.append(ord(char))
|
||||||
|
remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
|
||||||
|
elif remaining_src[0] == "[": # char range(s)
|
||||||
|
remaining_src = remaining_src[1:]
|
||||||
|
last_sym_start = len(outbuf)
|
||||||
|
# num chars in range - replaced at end of loop
|
||||||
|
outbuf.append(TO_BE_FILLED_MARKER)
|
||||||
|
while remaining_src[0] != "]":
|
||||||
|
char, remaining_src = parse_char(remaining_src)
|
||||||
|
|
||||||
|
outbuf.append(ord(char))
|
||||||
|
if remaining_src[0] == "-" and remaining_src[1] != "]":
|
||||||
|
endchar_pair, remaining_src = parse_char(remaining_src[1:])
|
||||||
|
outbuf.append(ord(endchar_pair))
|
||||||
|
else:
|
||||||
|
# chars that aren't part of a c1-c2 range are just doubled (i.e., c-c)
|
||||||
|
outbuf.append(ord(char))
|
||||||
|
# replace num chars with actual
|
||||||
|
outbuf[last_sym_start] = len(outbuf) - last_sym_start - 1
|
||||||
|
remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
|
||||||
|
elif is_word_char(remaining_src[0]): # rule reference
|
||||||
|
name, remaining_src = parse_name(remaining_src)
|
||||||
|
ref_rule_id = get_symbol_id(state, name)
|
||||||
|
remaining_src = remove_leading_white_space(remaining_src, is_nested)
|
||||||
|
last_sym_start = len(outbuf)
|
||||||
|
outbuf.append(REF_RULE_MARKER)
|
||||||
|
outbuf.append(ref_rule_id)
|
||||||
|
elif remaining_src[0] == "(": # grouping
|
||||||
|
# parse nested alternates into synthesized rule
|
||||||
|
remaining_src = remove_leading_white_space(remaining_src[1:], True)
|
||||||
|
sub_rule_id = generate_symbol_id(state, rule_name)
|
||||||
|
remaining_src = parse_alternates(state, remaining_src, rule_name, sub_rule_id, True)
|
||||||
|
last_sym_start = len(outbuf)
|
||||||
|
# output reference to synthesized rule
|
||||||
|
outbuf.append(REF_RULE_MARKER)
|
||||||
|
outbuf.append(sub_rule_id)
|
||||||
|
if remaining_src[0] != ")":
|
||||||
|
raise RuntimeError("expecting ')' at " + remaining_src)
|
||||||
|
remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
|
||||||
|
elif remaining_src[0] in ("*", "+", "?"): # repetition operator
|
||||||
|
if len(outbuf) - out_start_pos - 1 == 0:
|
||||||
|
raise RuntimeError("expecting preceeding item to */+/? at " + remaining_src)
|
||||||
|
out_grammar = state.grammar_encoding
|
||||||
|
|
||||||
|
# apply transformation to previous symbol (last_sym_start -
|
||||||
|
# end) according to rewrite rules:
|
||||||
|
# S* --> S' ::= S S' |
|
||||||
|
# S+ --> S' ::= S S' | S
|
||||||
|
# S? --> S' ::= S |
|
||||||
|
sub_rule_id = generate_symbol_id(state, rule_name)
|
||||||
|
out_grammar.append(sub_rule_id)
|
||||||
|
sub_rule_start = len(out_grammar)
|
||||||
|
# placeholder for size of 1st alternate
|
||||||
|
out_grammar.append(TO_BE_FILLED_MARKER)
|
||||||
|
# add preceding symbol to generated rule
|
||||||
|
out_grammar.extend(outbuf[last_sym_start:])
|
||||||
|
if remaining_src[0] in ("*", "+"):
|
||||||
|
# cause generated rule to recurse
|
||||||
|
out_grammar.append(REF_RULE_MARKER)
|
||||||
|
out_grammar.append(sub_rule_id)
|
||||||
|
# apply actual size
|
||||||
|
out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start
|
||||||
|
# mark end of 1st alternate
|
||||||
|
out_grammar.append(END_OF_ALTERNATE_MARKER)
|
||||||
|
sub_rule_start = len(out_grammar)
|
||||||
|
# placeholder for size of 2nd alternate
|
||||||
|
out_grammar.append(TO_BE_FILLED_MARKER)
|
||||||
|
if remaining_src[0] == "+":
|
||||||
|
# add preceding symbol as alternate only for '+'
|
||||||
|
out_grammar.extend(outbuf[last_sym_start:])
|
||||||
|
# apply actual size of 2nd alternate
|
||||||
|
out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start
|
||||||
|
# mark end of 2nd alternate, then end of rule
|
||||||
|
out_grammar.append(END_OF_ALTERNATE_MARKER)
|
||||||
|
out_grammar.append(END_OF_RULE_MARKER)
|
||||||
|
|
||||||
|
# in original rule, replace previous symbol with reference to generated rule
|
||||||
|
outbuf[last_sym_start:] = [1, sub_rule_id]
|
||||||
|
|
||||||
|
remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
# apply actual size of this alternate sequence
|
||||||
|
outbuf[out_start_pos] = len(outbuf) - out_start_pos
|
||||||
|
# mark end of alternate
|
||||||
|
outbuf.append(END_OF_ALTERNATE_MARKER)
|
||||||
|
return remaining_src
|
||||||
|
|
||||||
|
|
||||||
|
def parse_alternates(state, src, rule_name, rule_id, is_nested):
|
||||||
|
outbuf = []
|
||||||
|
remaining_src = parse_sequence(state, src, rule_name, outbuf, is_nested)
|
||||||
|
while remaining_src and remaining_src[0] == "|":
|
||||||
|
remaining_src = remove_leading_white_space(remaining_src[1:], True)
|
||||||
|
remaining_src = parse_sequence(state, remaining_src, rule_name, outbuf, is_nested)
|
||||||
|
|
||||||
|
state.grammar_encoding.append(rule_id)
|
||||||
|
state.grammar_encoding.extend(outbuf)
|
||||||
|
state.grammar_encoding.append(0)
|
||||||
|
return remaining_src
|
||||||
|
|
||||||
|
|
||||||
|
def parse_rule(state, src):
|
||||||
|
name, remaining_src = parse_name(src)
|
||||||
|
remaining_src = remove_leading_white_space(remaining_src, False)
|
||||||
|
rule_id = get_symbol_id(state, name)
|
||||||
|
|
||||||
|
if remaining_src[:3] != "::=":
|
||||||
|
raise RuntimeError("expecting ::= at " + remaining_src)
|
||||||
|
remaining_src = remove_leading_white_space(remaining_src[3:], True)
|
||||||
|
|
||||||
|
remaining_src = parse_alternates(state, remaining_src, name, rule_id, False)
|
||||||
|
|
||||||
|
if remaining_src and remaining_src[0] == "\r":
|
||||||
|
remaining_src = remaining_src[2:] if remaining_src[1] == "\n" else remaining_src[1:]
|
||||||
|
elif remaining_src and remaining_src[0] == "\n":
|
||||||
|
remaining_src = remaining_src[1:]
|
||||||
|
elif remaining_src:
|
||||||
|
raise RuntimeError("expecting newline or end at " + remaining_src)
|
||||||
|
return remove_leading_white_space(remaining_src, True)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ebnf(src):
|
||||||
|
try:
|
||||||
|
state = ParseState()
|
||||||
|
grammar_repr = remove_leading_white_space(src, True)
|
||||||
|
last_grammar_repr = ""
|
||||||
|
while grammar_repr:
|
||||||
|
if last_grammar_repr:
|
||||||
|
last_parsed_rule_len = len(last_grammar_repr) - len(grammar_repr)
|
||||||
|
logger.debug(f"last_parsed_rule: {last_grammar_repr[:last_parsed_rule_len]}")
|
||||||
|
last_grammar_repr = grammar_repr
|
||||||
|
grammar_repr = parse_rule(state, grammar_repr)
|
||||||
|
state.grammar_encoding.append(0xFFFF)
|
||||||
|
return state
|
||||||
|
except RuntimeError as err:
|
||||||
|
logger.warning("error parsing grammar:", err)
|
||||||
|
return ParseState()
|
||||||
|
|
||||||
|
|
||||||
|
def print_rule(file, grammar_encoding, index, symbol_id_names):
|
||||||
|
rule_id = grammar_encoding[index]
|
||||||
|
print(f"<{index}>{symbol_id_names[rule_id]} ::=", end=" ", file=file)
|
||||||
|
pos = index + 1
|
||||||
|
while grammar_encoding[pos]:
|
||||||
|
if pos - 1 > index:
|
||||||
|
print("|", end=" ", file=file)
|
||||||
|
pos += 1 # sequence size, not needed here
|
||||||
|
while grammar_encoding[pos]:
|
||||||
|
if grammar_encoding[pos] == REF_RULE_MARKER:
|
||||||
|
ref_rule_id = grammar_encoding[pos + 1]
|
||||||
|
print(
|
||||||
|
f"<{pos}>{symbol_id_names[ref_rule_id]}",
|
||||||
|
end=" ",
|
||||||
|
file=file,
|
||||||
|
)
|
||||||
|
pos += 2
|
||||||
|
else:
|
||||||
|
print("<{}>[".format(pos), end="", file=file)
|
||||||
|
num_chars = grammar_encoding[pos]
|
||||||
|
pos += 1
|
||||||
|
|
||||||
|
for i in range(0, num_chars, 2):
|
||||||
|
print("{}-".format(chr(grammar_encoding[pos + i])), end="", file=file)
|
||||||
|
if i + 1 < num_chars:
|
||||||
|
print("{}".format(chr(grammar_encoding[pos + i + 1])), end="", file=file)
|
||||||
|
print("]", end=" ", file=file)
|
||||||
|
pos += num_chars
|
||||||
|
pos += 1
|
||||||
|
print(file=file)
|
||||||
|
return pos + 1
|
||||||
|
|
||||||
|
|
||||||
|
def print_grammar(file, state):
|
||||||
|
pos = 0
|
||||||
|
symbol_id_names = {v: k for k, v in state.symbol_ids.items()}
|
||||||
|
print("Grammar Rules:", file=file)
|
||||||
|
|
||||||
|
while state.grammar_encoding[pos] != 0xFFFF:
|
||||||
|
pos = print_rule(file, state.grammar_encoding, pos, symbol_id_names)
|
||||||
|
pos = 0
|
||||||
|
print("\nBinary representation:", file=file)
|
||||||
|
while state.grammar_encoding[pos] != 0xFFFF:
|
||||||
|
print(f"{state.grammar_encoding[pos]:04x}", end=" ", file=file)
|
||||||
|
pos += 1
|
||||||
|
print("ffff\n")
|
||||||
|
|
||||||
|
|
||||||
|
###################################
|
||||||
|
# EBNF Grammar Parsing ends here #
|
||||||
|
###################################
|
||||||
|
|
||||||
|
|
||||||
|
class GrammarConstraint(ABC):
|
||||||
|
def __init__(self, grammar_str, start_rule_name, tokenizer):
|
||||||
|
self.tt = 0
|
||||||
|
self.nt = 0
|
||||||
|
state = parse_ebnf(grammar_str)
|
||||||
|
grammar_encoding = state.grammar_encoding
|
||||||
|
self.start_rule_id = state.symbol_ids.get(start_rule_name)
|
||||||
|
|
||||||
|
self.eos_token_id = tokenizer.eos_token_id
|
||||||
|
self.token_trie = TokenTrie(tokenizer)
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.grammar_encoding = grammar_encoding
|
||||||
|
|
||||||
|
pos = 0
|
||||||
|
rules: Dict[int, int] = {}
|
||||||
|
|
||||||
|
while grammar_encoding[pos] != 0xFFFF:
|
||||||
|
rule_id = grammar_encoding[pos]
|
||||||
|
|
||||||
|
# Store the current position in the 'rules' list at the index corresponding to rule_id.
|
||||||
|
# This effectively maps each rule_id to its position in the grammar encoding.
|
||||||
|
rules[rule_id] = pos
|
||||||
|
pos += 1
|
||||||
|
|
||||||
|
# Continue to the next rule in the encoding.
|
||||||
|
# The loop advances by the size indicated at the current position (grammar_encoding[pos])
|
||||||
|
# plus one for the size field itself.
|
||||||
|
while grammar_encoding[pos]:
|
||||||
|
pos += 1 + grammar_encoding[pos]
|
||||||
|
# Now we're at the end of the rule,
|
||||||
|
# so advance to the next rule by skipping the 0, which means 'end of rule'.
|
||||||
|
pos += 1
|
||||||
|
|
||||||
|
self.start_rule_pos = rules[self.start_rule_id]
|
||||||
|
self.rules_pos_dict: Dict[int, int] = rules
|
||||||
|
|
||||||
|
def init_stacks(self):
|
||||||
|
# suppose the start rule position is 0, then grammar_encoding[0] = rule_id
|
||||||
|
# grammar_encoding[1] = rule_size
|
||||||
|
# grammar_encoding[2] = rule_type
|
||||||
|
# this is why we need to add 2 to the start rule position
|
||||||
|
stack = [self.start_rule_pos + 2]
|
||||||
|
# convert to tuple for caching(immutable)
|
||||||
|
return self.advance_stack(tuple(stack))
|
||||||
|
|
||||||
|
# For each stack, resolve rules to find the actual characters that are
|
||||||
|
# accepted by this stack (not the set of sub-rules).
|
||||||
|
# This is where the parsing happens.
|
||||||
|
# The parsing is a top-down, left-to-right, depth-first traversal of the
|
||||||
|
# grammar.
|
||||||
|
@lru_cache(maxsize=32768)
|
||||||
|
def advance_stack(self, stack):
|
||||||
|
stack = list(stack)
|
||||||
|
# If the stack is empty, we're done. Because no more tokens should be accepted.
|
||||||
|
if len(stack) == 0:
|
||||||
|
return [stack]
|
||||||
|
|
||||||
|
# Get the top of the stack.
|
||||||
|
pos = stack[-1]
|
||||||
|
|
||||||
|
# If the stack head is a terminal(literal), we can resolve it immediately.
|
||||||
|
# literal is marked with 2 in the grammar encoding.
|
||||||
|
if self.grammar_encoding[pos] > 1:
|
||||||
|
return [stack]
|
||||||
|
|
||||||
|
# The stack head is a nonterminal (a rule reference, 1 in the grammar encoding).
|
||||||
|
# Resolving this rule gives a set of one or more possible positions
|
||||||
|
# (e.g. two in `a ::= b | c`)
|
||||||
|
# We pop the current rule off the stack and, for each option, push:
|
||||||
|
# - the symbol following this symbol in the current rule; then
|
||||||
|
# - the first symbol of the resolved rule.
|
||||||
|
referenced_rule_id = self.grammar_encoding[pos + 1]
|
||||||
|
|
||||||
|
# subpos should points to the size of the subrule
|
||||||
|
subpos = self.rules_pos_dict[referenced_rule_id] + 1
|
||||||
|
stacks: List[List[int]] = []
|
||||||
|
|
||||||
|
# do depth-first search to find all possible rules and check the next terminal
|
||||||
|
# When this value is non-zero, it indicates that subpos is not yet at the end of the rule, so we can continue.
|
||||||
|
# here subpos is a pointer, and the value in the rule encoding can never be 0 except for the end of the rule.
|
||||||
|
while self.grammar_encoding[subpos]:
|
||||||
|
new_stack = stack[:-1]
|
||||||
|
if self.grammar_encoding[pos + 2]:
|
||||||
|
# check if there is a next symbol in the current rule, e.g. `a ::= b c | d`
|
||||||
|
# if yes, push the pos to rule_size to the stack
|
||||||
|
new_stack.append(pos + 2)
|
||||||
|
|
||||||
|
# if the type of the next symbol is not "empty", push the first symbol of the resolved rule to the stack
|
||||||
|
if self.grammar_encoding[subpos + 1]:
|
||||||
|
new_stack.append(subpos + 1)
|
||||||
|
stacks.extend(self.advance_stack(tuple(new_stack)))
|
||||||
|
# The increment subpos += self.grammar_encoding[subpos] + 1
|
||||||
|
# moves subpos forward in the grammar encoding array to the next alternative in the current rule.
|
||||||
|
subpos += self.grammar_encoding[subpos] + 1
|
||||||
|
return stacks
|
||||||
|
|
||||||
|
def accept_char(self, *args, **kwargs):
|
||||||
|
"""Process a byte according to the grammar rules."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def accept_token_id(self, *args, **kwargs):
|
||||||
|
"""Process a token according to the grammar rules."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def filter_vocab(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class IncrementalGrammarConstraint(GrammarConstraint):
|
||||||
|
def __init__(self, grammar_str, start_rule_name, tokenizer):
|
||||||
|
super().__init__(grammar_str, start_rule_name, tokenizer)
|
||||||
|
|
||||||
|
def accept_char(self, byte, stacks):
|
||||||
|
new_stacks = []
|
||||||
|
for stack in stacks:
|
||||||
|
# stack is empty
|
||||||
|
if not stack:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pos = stack[-1]
|
||||||
|
num_chars = self.grammar_encoding[pos]
|
||||||
|
|
||||||
|
# to make pos point to the size of the char range rule
|
||||||
|
pos += 1
|
||||||
|
found = False
|
||||||
|
for i in range(0, num_chars, 2):
|
||||||
|
if self.grammar_encoding[pos + i] <= byte and byte <= self.grammar_encoding[pos + i + 1]:
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
if not found:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pos += num_chars
|
||||||
|
new_stack = stack[:-1]
|
||||||
|
if self.grammar_encoding[pos]:
|
||||||
|
new_stack.append(pos)
|
||||||
|
new_stacks.extend(self.advance_stack(tuple(new_stack)))
|
||||||
|
|
||||||
|
return new_stacks
|
||||||
|
|
||||||
|
def accept_string(self, string: str, stacks: List[List[int]]):
|
||||||
|
_bytes = bytes(string, "utf-8")
|
||||||
|
for byte in _bytes:
|
||||||
|
stacks = self.accept_char(byte, stacks)
|
||||||
|
return stacks
|
||||||
|
|
||||||
|
def accept_token_id(self, token_id: int, stacks: List[List[int]]):
|
||||||
|
if token_id == self.eos_token_id:
|
||||||
|
if stacks and all(len(stack) != 0 for stack in stacks):
|
||||||
|
raise Exception(
|
||||||
|
f"At least one of the stack should be empty when EOS is reached. However, "
|
||||||
|
f"the stacks are {stacks}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
for byte in self.token_trie.id2str(token_id):
|
||||||
|
stacks = self.accept_char(byte, stacks)
|
||||||
|
# check updated stacks
|
||||||
|
# TODO, I commented this out because it will fail when the stack is empty
|
||||||
|
# empty stack means the end of the grammar
|
||||||
|
# assert stacks != []
|
||||||
|
|
||||||
|
return stacks
|
||||||
|
|
||||||
|
def accept_token_ids(self, token_ids: List[int], stacks: List[List[int]], as_string=True):
|
||||||
|
if as_string:
|
||||||
|
string = self.tokenizer.decode(token_ids)
|
||||||
|
stacks = self.accept_string(string, stacks)
|
||||||
|
else:
|
||||||
|
for token_id in token_ids:
|
||||||
|
stacks = self.accept_token_id(token_id, stacks)
|
||||||
|
return stacks
|
||||||
|
|
||||||
|
def batch_filter_vocab(self, batch_stacks, device):
|
||||||
|
batch_acceptance = []
|
||||||
|
for stacks in batch_stacks:
|
||||||
|
batch_acceptance.append(self.filter_vocab(stacks, device))
|
||||||
|
return torch.stack(batch_acceptance)
|
||||||
|
|
||||||
|
def filter_vocab(self, stacks, device):
|
||||||
|
if not stacks: # Check if stacks is empty
|
||||||
|
# Handle the empty case: for example, return a tensor of False
|
||||||
|
# The size of the tensor should match the size of your vocabulary
|
||||||
|
vocab_size = len(self.token_trie)
|
||||||
|
logger.debug(f"sum of acceptance: {0}")
|
||||||
|
return torch.zeros(vocab_size, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
acceptance_matrix = torch.cat([self.token_acceptance_for_stack(tuple(stack), device) for stack in stacks])
|
||||||
|
# Merge stacks: any True => True
|
||||||
|
acceptance = acceptance_matrix.reshape(len(stacks), -1).any(dim=0)
|
||||||
|
logger.debug(f"sum of acceptance: {acceptance.sum()}")
|
||||||
|
return acceptance
|
||||||
|
|
||||||
|
# For each sub-rule in the grammar, cache whether each byte is accepted.
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def pos_char_acceptance(self, pos):
|
||||||
|
acceptance = [False] * 256
|
||||||
|
num_chars = self.grammar_encoding[pos]
|
||||||
|
pos += 1
|
||||||
|
for i in range(0, num_chars, 2):
|
||||||
|
start = self.grammar_encoding[pos + i]
|
||||||
|
end = self.grammar_encoding[pos + i + 1]
|
||||||
|
for j in range(start, end + 1):
|
||||||
|
acceptance[j] = True
|
||||||
|
return acceptance
|
||||||
|
|
||||||
|
# Probably this should be configurable. If the grammar has an exceedingly
|
||||||
|
# large number of states, the correct setting is a tradeoff between GPU
|
||||||
|
# RAM usage and recomputation time.
|
||||||
|
#
|
||||||
|
# The main variable that pushes usage up here is number of states in the
|
||||||
|
# grammar.
|
||||||
|
@lru_cache(maxsize=32768)
|
||||||
|
def token_acceptance_for_stack(self, stack, device):
|
||||||
|
st = time.time()
|
||||||
|
stack = list(stack) # needs to come in as a tuple for lru_cache
|
||||||
|
|
||||||
|
accepts = [False] * len(self.token_trie)
|
||||||
|
accepts[self.eos_token_id] = len(stack) == 0
|
||||||
|
if len(stack) == 0:
|
||||||
|
logger.debug("empty stack")
|
||||||
|
|
||||||
|
def traverse_trie(trie, stacks):
|
||||||
|
for byte, next_trie in trie.items():
|
||||||
|
if byte == LEAF:
|
||||||
|
token_id = next_trie
|
||||||
|
if token_id != self.eos_token_id:
|
||||||
|
accepts[token_id] = bool(stacks)
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_stacks = []
|
||||||
|
for stk in stacks:
|
||||||
|
if not stk:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pos = stk[-1]
|
||||||
|
num_chars = self.grammar_encoding[pos]
|
||||||
|
|
||||||
|
if not self.pos_char_acceptance(pos)[byte]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pos += num_chars + 1
|
||||||
|
new_stack = stk[:-1]
|
||||||
|
if self.grammar_encoding[pos]:
|
||||||
|
new_stack.append(pos)
|
||||||
|
new_stacks.extend(self.advance_stack(tuple(new_stack)))
|
||||||
|
|
||||||
|
if new_stacks:
|
||||||
|
traverse_trie(next_trie, new_stacks)
|
||||||
|
|
||||||
|
traverse_trie(self.token_trie.trie, [stack])
|
||||||
|
|
||||||
|
et = time.time() - st
|
||||||
|
x = torch.tensor(accepts, dtype=torch.bool, device=device)
|
||||||
|
self.tt += et
|
||||||
|
self.nt += 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class StaticGrammarConstraint(GrammarConstraint):
|
||||||
|
def __init__(self, grammar_str, start_rule_name, tokenizer):
|
||||||
|
super().__init__(grammar_str, start_rule_name, tokenizer)
|
||||||
|
|
||||||
|
def accept_char(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
#################
|
||||||
|
# DATA STRUCTURES
|
||||||
|
#################
|
||||||
|
|
||||||
|
|
||||||
|
LEAF = -1
|
||||||
|
|
||||||
|
|
||||||
|
class TokenTrie:
|
||||||
|
def __init__(self, tokenizer):
|
||||||
|
self.eos_token_id = tokenizer.eos_token_id
|
||||||
|
self.tokens = []
|
||||||
|
self.trie = {}
|
||||||
|
self.load_tokens(tokenizer)
|
||||||
|
|
||||||
|
def id2str(self, token_id):
|
||||||
|
return self.tokens[token_id]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.tokens)
|
||||||
|
|
||||||
|
def load_tokens(self, tokenizer):
|
||||||
|
def replace_hex(match):
|
||||||
|
hex_value = match.group(1)
|
||||||
|
return chr(int(hex_value, 16))
|
||||||
|
|
||||||
|
if "gpt2" in tokenizer.__class__.__name__.lower():
|
||||||
|
special = tokenizer.additional_special_tokens_ids
|
||||||
|
|
||||||
|
# Here, the decoder does a string replace on a bunch of sequences
|
||||||
|
# like ' .' for '.'. This interferes with our assumptions, where a
|
||||||
|
# token should always have exactly one representation.
|
||||||
|
# Fortunately(?) text-generation-inference doesn't seem to run this
|
||||||
|
# cleanup, so we get extraneous spaces. So, in order to generate
|
||||||
|
# the right token set for TGI, we have to skip the space trimming.
|
||||||
|
# See:
|
||||||
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3588-L3600
|
||||||
|
def fmt_token(id):
|
||||||
|
if id in special:
|
||||||
|
return None
|
||||||
|
return bytes(tokenizer.decode([id], clean_up_tokenization_spaces=False), "utf-8")
|
||||||
|
|
||||||
|
elif "llama" in tokenizer.__class__.__name__.lower():
|
||||||
|
|
||||||
|
def fmt_token(id):
|
||||||
|
token = tokenizer.convert_ids_to_tokens(id)
|
||||||
|
token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
|
||||||
|
token = token.replace("▁", " ")
|
||||||
|
return bytes(token, "utf-8")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("Warning: unrecognized tokenizer: using default token formatting")
|
||||||
|
|
||||||
|
def fmt_token(id):
|
||||||
|
token = tokenizer.convert_ids_to_tokens(id)
|
||||||
|
return bytes(token, "utf-8")
|
||||||
|
|
||||||
|
# note: vocab_size doesn't work here because there are also
|
||||||
|
# get_added_vocab() tokens
|
||||||
|
self.tokens = [fmt_token(i) for i in range(len(tokenizer.get_vocab()))]
|
||||||
|
for token_id, token_bytes in enumerate(self.tokens):
|
||||||
|
if token_bytes is not None:
|
||||||
|
self.insert_into_trie(self.trie, token_bytes, token_id)
|
||||||
|
|
||||||
|
def insert_into_trie(self, trie, token_bytes, token_id):
|
||||||
|
current = trie
|
||||||
|
for byte in token_bytes:
|
||||||
|
if byte not in current:
|
||||||
|
current[byte] = {}
|
||||||
|
current = current[byte]
|
||||||
|
current[LEAF] = token_id
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=5)
|
||||||
|
def initialize_grammar(grammar_string):
|
||||||
|
return IncrementalGrammarConstraint(grammar_string.strip(), start_rule_name="root", tokenizer=shared.tokenizer)
|
104
modules/grammar/logits_process.py
Normal file
104
modules/grammar/logits_process.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
'''
|
||||||
|
This file has been 100% copied from this PR to the Transformers library:
|
||||||
|
https://github.com/huggingface/transformers/pull/27557
|
||||||
|
|
||||||
|
Author: Saibo-creator
|
||||||
|
Author GitHub: https://github.com/Saibo-creator
|
||||||
|
|
||||||
|
All credits go to the author.
|
||||||
|
'''
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.generation.logits_process import LogitsProcessor
|
||||||
|
from transformers.utils import add_start_docstrings
|
||||||
|
|
||||||
|
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
||||||
|
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
|
||||||
|
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
||||||
|
search or log softmax for each vocabulary token when using beam search
|
||||||
|
|
||||||
|
Return:
|
||||||
|
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class GrammarConstrainedLogitsProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, grammar_constraint):
|
||||||
|
self.last_size = None
|
||||||
|
self.grammar_constraint = grammar_constraint
|
||||||
|
self.batch_stacks = None
|
||||||
|
|
||||||
|
def filter_logits(self, logits, device):
|
||||||
|
# resolve each stack to a tensor of True/False for each token
|
||||||
|
# indicating acceptance
|
||||||
|
# acceptance = self.grammar_acceptor.filter_vocab(self.stacks, device)
|
||||||
|
acceptance = self.grammar_constraint.batch_filter_vocab(self.batch_stacks, device)
|
||||||
|
# logger.debug(acceptance)
|
||||||
|
# Logits to -inf where False
|
||||||
|
logits[~acceptance] = -math.inf
|
||||||
|
|
||||||
|
# TODO: batching
|
||||||
|
def process_logits(self, input_ids, scores, parse_start_index=None):
|
||||||
|
"""
|
||||||
|
:param input_ids:
|
||||||
|
:param scores:
|
||||||
|
:param parse_start_index: default None, which means generate from scratch. Set to 0 to parse all input_ids
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# we dynamically create stacks at the first call, so that we know the batch size and beam size
|
||||||
|
if self.batch_stacks is None:
|
||||||
|
self.batch_stacks = [self.grammar_constraint.init_stacks() for _ in range(len(input_ids))]
|
||||||
|
|
||||||
|
# if self.last_size is not set (which would be the case when processing the first token).
|
||||||
|
# In this case, do nothing.
|
||||||
|
if self.last_size is None:
|
||||||
|
prefix_to_parse = [
|
||||||
|
single_input_ids[parse_start_index:] if parse_start_index is not None else []
|
||||||
|
for single_input_ids in input_ids
|
||||||
|
]
|
||||||
|
# self.grammar_acceptor.accept_token_ids(prefix_to_parse, self.stacks)
|
||||||
|
self.batch_stacks = [
|
||||||
|
self.grammar_constraint.accept_token_ids(prefix, stack)
|
||||||
|
for prefix, stack in zip(prefix_to_parse, self.batch_stacks)
|
||||||
|
]
|
||||||
|
# if the length of the current input IDs (input_ids[0]) is exactly one more than self.last_size.
|
||||||
|
# This is expected in a scenario where inputs are processed incrementally, one token at a time.
|
||||||
|
elif len(input_ids[0]) == self.last_size + 1:
|
||||||
|
# self.stacks = self.grammar_acceptor.accept_token_id(input_ids[0][-1], self.stacks)
|
||||||
|
self.batch_stacks = [
|
||||||
|
self.grammar_constraint.accept_token_id(single_input_ids[-1], stack)
|
||||||
|
for single_input_ids, stack in zip(input_ids, self.batch_stacks)
|
||||||
|
]
|
||||||
|
# ensure that the input size is consistent with the expected incremental processing
|
||||||
|
# (i.e., one token at a time).
|
||||||
|
else:
|
||||||
|
# here we check if the input_ids are one token longer than the last time we processed
|
||||||
|
# but we don't check if input_ids are actually valid.
|
||||||
|
# Imagine a scenario where we generate 10 tokens, then we replace the 10 generated tokens with 10 new tokens.
|
||||||
|
# In this case, the input_ids will be consistent with the last_size, but the input_ids are not valid.
|
||||||
|
# However, should we really check if the input_ids are valid here?
|
||||||
|
# If we do, then we need to reparse the whole input_ids at each call, which is not efficient.
|
||||||
|
# Maybe we should just trust the user to provide valid input_ids?
|
||||||
|
# The conclusion is that, we assume the input_ids are valid, and our generation will be correct.
|
||||||
|
# If the input_ids are not valid, then the generation result will be wrong and we don't take responsibility for that.
|
||||||
|
raise RuntimeError(
|
||||||
|
"Input ID's length is inconsistent with the current state of "
|
||||||
|
"the GrammarConstrainedLogitsProcessor. If you want to process "
|
||||||
|
"another input sequence, please instantiate a new "
|
||||||
|
"GrammarConstrainedLogitsProcessor."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.filter_logits(scores, scores.device)
|
||||||
|
|
||||||
|
self.last_size = len(input_ids[0])
|
||||||
|
return scores
|
||||||
|
|
||||||
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
return self.process_logits(input_ids, scores)
|
@ -18,7 +18,8 @@ from modules.callbacks import (
|
|||||||
_StopEverythingStoppingCriteria
|
_StopEverythingStoppingCriteria
|
||||||
)
|
)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.grammar import GrammarLogitsProcessor
|
from modules.grammar.grammar_utils import initialize_grammar
|
||||||
|
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
|
||||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.models import clear_torch_cache, local_rank
|
from modules.models import clear_torch_cache, local_rank
|
||||||
@ -317,11 +318,17 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
|
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
|
||||||
|
|
||||||
|
# Logits processor
|
||||||
processor = state.get('logits_processor', LogitsProcessorList([]))
|
processor = state.get('logits_processor', LogitsProcessorList([]))
|
||||||
# In case a processor is passed by itself.
|
|
||||||
if not isinstance(processor, LogitsProcessorList):
|
if not isinstance(processor, LogitsProcessorList):
|
||||||
processor = LogitsProcessorList([processor])
|
processor = LogitsProcessorList([processor])
|
||||||
processor.append(GrammarLogitsProcessor(state['grammar_string']))
|
|
||||||
|
# Grammar
|
||||||
|
if state['grammar_string'].strip() != '':
|
||||||
|
grammar = initialize_grammar(state['grammar_string'])
|
||||||
|
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
|
||||||
|
processor.append(grammar_processor)
|
||||||
|
|
||||||
apply_extensions('logits_processor', processor, input_ids)
|
apply_extensions('logits_processor', processor, input_ids)
|
||||||
generate_params['logits_processor'] = processor
|
generate_params['logits_processor'] = processor
|
||||||
|
|
||||||
|
@ -20,8 +20,6 @@ transformers==4.36.*
|
|||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/oobabooga/torch-grammar.git
|
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
@ -20,8 +20,6 @@ transformers==4.36.*
|
|||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/oobabooga/torch-grammar.git
|
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.38.1; platform_system != "Windows"
|
bitsandbytes==0.38.1; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.38.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.38.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
@ -20,8 +20,6 @@ transformers==4.36.*
|
|||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/oobabooga/torch-grammar.git
|
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.38.1; platform_system != "Windows"
|
bitsandbytes==0.38.1; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.38.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.38.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
@ -20,8 +20,6 @@ transformers==4.36.*
|
|||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/oobabooga/torch-grammar.git
|
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
@ -20,8 +20,6 @@ transformers==4.36.*
|
|||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/oobabooga/torch-grammar.git
|
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
@ -20,8 +20,6 @@ transformers==4.36.*
|
|||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/oobabooga/torch-grammar.git
|
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
@ -20,8 +20,6 @@ transformers==4.36.*
|
|||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/oobabooga/torch-grammar.git
|
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
@ -20,8 +20,6 @@ transformers==4.36.*
|
|||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/oobabooga/torch-grammar.git
|
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
@ -20,8 +20,6 @@ transformers==4.36.*
|
|||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/oobabooga/torch-grammar.git
|
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
Loading…
Reference in New Issue
Block a user