text-generation-webui/modules/grammar.py

34 lines
869 B
Python
Raw Permalink Normal View History

from torch_grammar import GrammarSampler
from transformers.generation.logits_process import LogitsProcessor
from modules import shared
# Module-level cache shared by GrammarLogitsProcessor instances (they rebind
# these names via `global`), so an unchanged grammar is not recompiled.
grammar_string = ''          # grammar text the cached `sampler` corresponds to
sampler = grammar = None     # compiled GrammarSampler / latest logits processor
class GrammarLogitsProcessor(LogitsProcessor):
    """Logits processor that constrains generation with a ``torch_grammar`` grammar.

    Compiling grammar text into a ``GrammarSampler`` is cached in the module
    globals ``sampler`` / ``grammar_string``. Repeated requests with the same
    grammar text and the same tokenizer therefore reuse the compiled sampler.
    Every instance still obtains its own processor from
    ``sampler.logits_processor()`` and uses only that one in ``__call__``.

    Args:
        string: Grammar source text. It is stripped, a trailing newline is
            appended, and it is compiled with rule name ``'root'`` against
            ``shared.tokenizer``. An empty or whitespace-only string disables
            the constraint, and scores pass through unchanged.
    """

    # Tokenizer object the cached module-level ``sampler`` was compiled with.
    # The cache key includes it so that replacing ``shared.tokenizer`` (e.g.
    # loading another model) forces a rebuild instead of reusing a sampler
    # built for a different vocabulary.
    _sampler_tokenizer = None

    def __init__(self, string):
        global sampler, grammar, grammar_string

        tokenizer = shared.tokenizer
        cache_is_stale = (
            string != grammar_string
            or tokenizer is not GrammarLogitsProcessor._sampler_tokenizer
        )
        if cache_is_stale:
            if string.strip() != '':
                # Build before touching the cache. If GrammarSampler raises on
                # an invalid grammar, the cache must not pair this text with
                # the previous grammar's sampler.
                new_sampler = GrammarSampler(string.strip() + '\n', 'root', tokenizer)
            else:
                new_sampler = None
            sampler = new_sampler
            grammar_string = string
            GrammarLogitsProcessor._sampler_tokenizer = tokenizer

        # A fresh processor per instance. NOTE(review): presumably it carries
        # per-generation parse state; confirm in torch_grammar. It is kept on
        # the instance so constructing another processor later (e.g. for a
        # different request) cannot swap the grammar this one applies.
        self._grammar = sampler.logits_processor() if sampler is not None else None
        # Module-level mirror kept for any code that reads `grammar` directly.
        grammar = self._grammar

    def __call__(self, input_ids, scores):
        """Return ``scores`` filtered by this instance's grammar processor.

        If no grammar is active, ``scores`` is returned unchanged.
        """
        if self._grammar is not None:
            scores = self._grammar(input_ids, scores)
        return scores