mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Add grammar to transformers and _HF loaders (#4091)
This commit is contained in:
parent
0197fdddf1
commit
ae4ba3007f
33
modules/grammar.py
Normal file
33
modules/grammar.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from torch_grammar import GrammarSampler
|
||||||
|
from transformers.generation.logits_process import LogitsProcessor
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
sampler = None
|
||||||
|
grammar = None
|
||||||
|
grammar_string = ''
|
||||||
|
|
||||||
|
|
||||||
|
class GrammarLogitsProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, string):
|
||||||
|
|
||||||
|
global sampler, grammar, grammar_string
|
||||||
|
|
||||||
|
if string != grammar_string:
|
||||||
|
grammar_string = string
|
||||||
|
if string.strip() != '':
|
||||||
|
string = string.strip() + '\n'
|
||||||
|
sampler = GrammarSampler(string, 'root', shared.tokenizer)
|
||||||
|
else:
|
||||||
|
sampler = None
|
||||||
|
|
||||||
|
if sampler is not None:
|
||||||
|
grammar = sampler.logits_processor()
|
||||||
|
else:
|
||||||
|
grammar = None
|
||||||
|
|
||||||
|
def __call__(self, input_ids, scores):
|
||||||
|
if grammar is not None:
|
||||||
|
scores = grammar(input_ids, scores)
|
||||||
|
|
||||||
|
return scores
|
@ -156,6 +156,8 @@ loaders_samplers = {
|
|||||||
'mirostat_mode',
|
'mirostat_mode',
|
||||||
'mirostat_tau',
|
'mirostat_tau',
|
||||||
'mirostat_eta',
|
'mirostat_eta',
|
||||||
|
'grammar_file_row',
|
||||||
|
'grammar_string',
|
||||||
'guidance_scale',
|
'guidance_scale',
|
||||||
'negative_prompt',
|
'negative_prompt',
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
@ -183,6 +185,8 @@ loaders_samplers = {
|
|||||||
'mirostat_mode',
|
'mirostat_mode',
|
||||||
'mirostat_tau',
|
'mirostat_tau',
|
||||||
'mirostat_eta',
|
'mirostat_eta',
|
||||||
|
'grammar_file_row',
|
||||||
|
'grammar_string',
|
||||||
'guidance_scale',
|
'guidance_scale',
|
||||||
'negative_prompt',
|
'negative_prompt',
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
@ -236,6 +240,8 @@ loaders_samplers = {
|
|||||||
'mirostat_mode',
|
'mirostat_mode',
|
||||||
'mirostat_tau',
|
'mirostat_tau',
|
||||||
'mirostat_eta',
|
'mirostat_eta',
|
||||||
|
'grammar_file_row',
|
||||||
|
'grammar_string',
|
||||||
'guidance_scale',
|
'guidance_scale',
|
||||||
'negative_prompt',
|
'negative_prompt',
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
@ -267,6 +273,8 @@ loaders_samplers = {
|
|||||||
'mirostat_mode',
|
'mirostat_mode',
|
||||||
'mirostat_tau',
|
'mirostat_tau',
|
||||||
'mirostat_eta',
|
'mirostat_eta',
|
||||||
|
'grammar_file_row',
|
||||||
|
'grammar_string',
|
||||||
'guidance_scale',
|
'guidance_scale',
|
||||||
'negative_prompt',
|
'negative_prompt',
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
@ -298,6 +306,8 @@ loaders_samplers = {
|
|||||||
'mirostat_mode',
|
'mirostat_mode',
|
||||||
'mirostat_tau',
|
'mirostat_tau',
|
||||||
'mirostat_eta',
|
'mirostat_eta',
|
||||||
|
'grammar_file_row',
|
||||||
|
'grammar_string',
|
||||||
'guidance_scale',
|
'guidance_scale',
|
||||||
'negative_prompt',
|
'negative_prompt',
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
@ -339,6 +349,8 @@ loaders_samplers = {
|
|||||||
'mirostat_mode',
|
'mirostat_mode',
|
||||||
'mirostat_tau',
|
'mirostat_tau',
|
||||||
'mirostat_eta',
|
'mirostat_eta',
|
||||||
|
'grammar_file_row',
|
||||||
|
'grammar_string',
|
||||||
'guidance_scale',
|
'guidance_scale',
|
||||||
'negative_prompt',
|
'negative_prompt',
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
|
@ -18,6 +18,7 @@ from modules.callbacks import (
|
|||||||
_StopEverythingStoppingCriteria
|
_StopEverythingStoppingCriteria
|
||||||
)
|
)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
|
from modules.grammar import GrammarLogitsProcessor
|
||||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.models import clear_torch_cache, local_rank
|
from modules.models import clear_torch_cache, local_rank
|
||||||
@ -319,6 +320,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
# In case a processor is passed by itself.
|
# In case a processor is passed by itself.
|
||||||
if not isinstance(processor, LogitsProcessorList):
|
if not isinstance(processor, LogitsProcessorList):
|
||||||
processor = LogitsProcessorList([processor])
|
processor = LogitsProcessorList([processor])
|
||||||
|
processor.append(GrammarLogitsProcessor(state['grammar_string']))
|
||||||
apply_extensions('logits_processor', processor, input_ids)
|
apply_extensions('logits_processor', processor, input_ids)
|
||||||
generate_params['logits_processor'] = processor
|
generate_params['logits_processor'] = processor
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ tqdm
|
|||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||||
|
git+https://github.com/oobabooga/torch-grammar.git
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
|
@ -25,6 +25,7 @@ tqdm
|
|||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||||
|
git+https://github.com/oobabooga/torch-grammar.git
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.38.1; platform_system != "Windows"
|
bitsandbytes==0.38.1; platform_system != "Windows"
|
||||||
|
@ -25,6 +25,7 @@ tqdm
|
|||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||||
|
git+https://github.com/oobabooga/torch-grammar.git
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.38.1; platform_system != "Windows"
|
bitsandbytes==0.38.1; platform_system != "Windows"
|
||||||
|
@ -25,6 +25,7 @@ tqdm
|
|||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||||
|
git+https://github.com/oobabooga/torch-grammar.git
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
|
@ -25,6 +25,7 @@ tqdm
|
|||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||||
|
git+https://github.com/oobabooga/torch-grammar.git
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
|
@ -25,6 +25,7 @@ tqdm
|
|||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||||
|
git+https://github.com/oobabooga/torch-grammar.git
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
|
@ -25,6 +25,7 @@ tqdm
|
|||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||||
|
git+https://github.com/oobabooga/torch-grammar.git
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
|
@ -25,6 +25,7 @@ tqdm
|
|||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||||
|
git+https://github.com/oobabooga/torch-grammar.git
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
|
@ -25,6 +25,7 @@ tqdm
|
|||||||
wandb
|
wandb
|
||||||
|
|
||||||
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
|
||||||
|
git+https://github.com/oobabooga/torch-grammar.git
|
||||||
|
|
||||||
# bitsandbytes
|
# bitsandbytes
|
||||||
bitsandbytes==0.41.1; platform_system != "Windows"
|
bitsandbytes==0.41.1; platform_system != "Windows"
|
||||||
|
Loading…
Reference in New Issue
Block a user