mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add grammar to transformers and _HF loaders (#4091)
This commit is contained in:
parent
0197fdddf1
commit
ae4ba3007f
33
modules/grammar.py
Normal file
33
modules/grammar.py
Normal file
@ -0,0 +1,33 @@
|
||||
from torch_grammar import GrammarSampler
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
|
||||
from modules import shared
|
||||
|
||||
sampler = None
|
||||
grammar = None
|
||||
grammar_string = ''
|
||||
|
||||
|
||||
class GrammarLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self, string):
|
||||
|
||||
global sampler, grammar, grammar_string
|
||||
|
||||
if string != grammar_string:
|
||||
grammar_string = string
|
||||
if string.strip() != '':
|
||||
string = string.strip() + '\n'
|
||||
sampler = GrammarSampler(string, 'root', shared.tokenizer)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
if sampler is not None:
|
||||
grammar = sampler.logits_processor()
|
||||
else:
|
||||
grammar = None
|
||||
|
||||
def __call__(self, input_ids, scores):
|
||||
if grammar is not None:
|
||||
scores = grammar(input_ids, scores)
|
||||
|
||||
return scores
|
@ -156,6 +156,8 @@ loaders_samplers = {
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'grammar_file_row',
|
||||
'grammar_string',
|
||||
'guidance_scale',
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
@ -183,6 +185,8 @@ loaders_samplers = {
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'grammar_file_row',
|
||||
'grammar_string',
|
||||
'guidance_scale',
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
@ -236,6 +240,8 @@ loaders_samplers = {
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'grammar_file_row',
|
||||
'grammar_string',
|
||||
'guidance_scale',
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
@ -267,6 +273,8 @@ loaders_samplers = {
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'grammar_file_row',
|
||||
'grammar_string',
|
||||
'guidance_scale',
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
@ -298,6 +306,8 @@ loaders_samplers = {
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'grammar_file_row',
|
||||
'grammar_string',
|
||||
'guidance_scale',
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
@ -339,6 +349,8 @@ loaders_samplers = {
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'grammar_file_row',
|
||||
'grammar_string',
|
||||
'guidance_scale',
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
|
@ -18,6 +18,7 @@ from modules.callbacks import (
|
||||
_StopEverythingStoppingCriteria
|
||||
)
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.grammar import GrammarLogitsProcessor
|
||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||
from modules.logging_colors import logger
|
||||
from modules.models import clear_torch_cache, local_rank
|
||||
@ -319,6 +320,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
# In case a processor is passed by itself.
|
||||
if not isinstance(processor, LogitsProcessorList):
|
||||
processor = LogitsProcessorList([processor])
|
||||
processor.append(GrammarLogitsProcessor(state['grammar_string']))
|
||||
apply_extensions('logits_processor', processor, input_ids)
|
||||
generate_params['logits_processor'] = processor
|
||||
|
||||
|
@ -25,6 +25,7 @@ tqdm
|
||||
wandb
|
||||
|
||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||
git+https://github.com/oobabooga/torch-grammar.git
|
||||
|
||||
# bitsandbytes
|
||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||
|
@ -25,6 +25,7 @@ tqdm
|
||||
wandb
|
||||
|
||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||
git+https://github.com/oobabooga/torch-grammar.git
|
||||
|
||||
# bitsandbytes
|
||||
bitsandbytes==0.38.1; platform_system != "Windows"
|
||||
|
@ -25,6 +25,7 @@ tqdm
|
||||
wandb
|
||||
|
||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||
git+https://github.com/oobabooga/torch-grammar.git
|
||||
|
||||
# bitsandbytes
|
||||
bitsandbytes==0.38.1; platform_system != "Windows"
|
||||
|
@ -25,6 +25,7 @@ tqdm
|
||||
wandb
|
||||
|
||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||
git+https://github.com/oobabooga/torch-grammar.git
|
||||
|
||||
# bitsandbytes
|
||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||
|
@ -25,6 +25,7 @@ tqdm
|
||||
wandb
|
||||
|
||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||
git+https://github.com/oobabooga/torch-grammar.git
|
||||
|
||||
# bitsandbytes
|
||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||
|
@ -25,6 +25,7 @@ tqdm
|
||||
wandb
|
||||
|
||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||
git+https://github.com/oobabooga/torch-grammar.git
|
||||
|
||||
# bitsandbytes
|
||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||
|
@ -25,6 +25,7 @@ tqdm
|
||||
wandb
|
||||
|
||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||
git+https://github.com/oobabooga/torch-grammar.git
|
||||
|
||||
# bitsandbytes
|
||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||
|
@ -25,6 +25,7 @@ tqdm
|
||||
wandb
|
||||
|
||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||
git+https://github.com/oobabooga/torch-grammar.git
|
||||
|
||||
# bitsandbytes
|
||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||
|
@ -25,6 +25,7 @@ tqdm
|
||||
wandb
|
||||
|
||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||
git+https://github.com/oobabooga/torch-grammar.git
|
||||
|
||||
# bitsandbytes
|
||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||
|
Loading…
Reference in New Issue
Block a user