mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add grammar to llama.cpp loader (closes #4019)
This commit is contained in:
parent
a3ad9fe6c0
commit
b227e65d86
@ -63,6 +63,7 @@ async def run(user_input, history):
|
|||||||
'mirostat_mode': 0,
|
'mirostat_mode': 0,
|
||||||
'mirostat_tau': 5,
|
'mirostat_tau': 5,
|
||||||
'mirostat_eta': 0.1,
|
'mirostat_eta': 0.1,
|
||||||
|
'grammar_file': '',
|
||||||
'guidance_scale': 1,
|
'guidance_scale': 1,
|
||||||
'negative_prompt': '',
|
'negative_prompt': '',
|
||||||
|
|
||||||
|
@ -57,6 +57,7 @@ def run(user_input, history):
|
|||||||
'mirostat_mode': 0,
|
'mirostat_mode': 0,
|
||||||
'mirostat_tau': 5,
|
'mirostat_tau': 5,
|
||||||
'mirostat_eta': 0.1,
|
'mirostat_eta': 0.1,
|
||||||
|
'grammar_file': '',
|
||||||
'guidance_scale': 1,
|
'guidance_scale': 1,
|
||||||
'negative_prompt': '',
|
'negative_prompt': '',
|
||||||
|
|
||||||
|
@ -46,6 +46,7 @@ async def run(context):
|
|||||||
'mirostat_mode': 0,
|
'mirostat_mode': 0,
|
||||||
'mirostat_tau': 5,
|
'mirostat_tau': 5,
|
||||||
'mirostat_eta': 0.1,
|
'mirostat_eta': 0.1,
|
||||||
|
'grammar_file': '',
|
||||||
'guidance_scale': 1,
|
'guidance_scale': 1,
|
||||||
'negative_prompt': '',
|
'negative_prompt': '',
|
||||||
|
|
||||||
|
@ -38,6 +38,7 @@ def run(prompt):
|
|||||||
'mirostat_mode': 0,
|
'mirostat_mode': 0,
|
||||||
'mirostat_tau': 5,
|
'mirostat_tau': 5,
|
||||||
'mirostat_eta': 0.1,
|
'mirostat_eta': 0.1,
|
||||||
|
'grammar_file': '',
|
||||||
'guidance_scale': 1,
|
'guidance_scale': 1,
|
||||||
'negative_prompt': '',
|
'negative_prompt': '',
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ def build_parameters(body, chat=False):
|
|||||||
'mirostat_mode': int(body.get('mirostat_mode', 0)),
|
'mirostat_mode': int(body.get('mirostat_mode', 0)),
|
||||||
'mirostat_tau': float(body.get('mirostat_tau', 5)),
|
'mirostat_tau': float(body.get('mirostat_tau', 5)),
|
||||||
'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
|
'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
|
||||||
|
'grammar_file': str(body.get('grammar_file', '')),
|
||||||
'guidance_scale': float(body.get('guidance_scale', 1)),
|
'guidance_scale': float(body.get('guidance_scale', 1)),
|
||||||
'negative_prompt': str(body.get('negative_prompt', '')),
|
'negative_prompt': str(body.get('negative_prompt', '')),
|
||||||
'seed': int(body.get('seed', -1)),
|
'seed': int(body.get('seed', -1)),
|
||||||
|
@ -34,6 +34,7 @@ default_req_params = {
|
|||||||
'mirostat_mode': 0,
|
'mirostat_mode': 0,
|
||||||
'mirostat_tau': 5.0,
|
'mirostat_tau': 5.0,
|
||||||
'mirostat_eta': 0.1,
|
'mirostat_eta': 0.1,
|
||||||
|
'grammar_file': '',
|
||||||
'guidance_scale': 1,
|
'guidance_scale': 1,
|
||||||
'negative_prompt': '',
|
'negative_prompt': '',
|
||||||
'ban_eos_token': False,
|
'ban_eos_token': False,
|
||||||
|
6
grammars/arithmetic.gbnf
Normal file
6
grammars/arithmetic.gbnf
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
root ::= (expr "=" ws term "\n")+
|
||||||
|
expr ::= term ([-+*/] term)*
|
||||||
|
term ::= ident | num | "(" ws expr ")" ws
|
||||||
|
ident ::= [a-z] [a-z0-9_]* ws
|
||||||
|
num ::= [0-9]+ ws
|
||||||
|
ws ::= [ \t\n]*
|
42
grammars/c.gbnf
Normal file
42
grammars/c.gbnf
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
root ::= (declaration)*
|
||||||
|
|
||||||
|
declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"
|
||||||
|
|
||||||
|
dataType ::= "int" ws | "float" ws | "char" ws
|
||||||
|
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
|
||||||
|
|
||||||
|
parameter ::= dataType identifier
|
||||||
|
|
||||||
|
statement ::=
|
||||||
|
( dataType identifier ws "=" ws expression ";" ) |
|
||||||
|
( identifier ws "=" ws expression ";" ) |
|
||||||
|
( identifier ws "(" argList? ")" ";" ) |
|
||||||
|
( "return" ws expression ";" ) |
|
||||||
|
( "while" "(" condition ")" "{" statement* "}" ) |
|
||||||
|
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
|
||||||
|
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
|
||||||
|
( singleLineComment ) |
|
||||||
|
( multiLineComment )
|
||||||
|
|
||||||
|
forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
|
||||||
|
forUpdate ::= identifier ws "=" ws expression
|
||||||
|
|
||||||
|
condition ::= expression relationOperator expression
|
||||||
|
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")
|
||||||
|
|
||||||
|
expression ::= term (("+" | "-") term)*
|
||||||
|
term ::= factor(("*" | "/") factor)*
|
||||||
|
|
||||||
|
factor ::= identifier | number | unaryTerm | funcCall | parenExpression
|
||||||
|
unaryTerm ::= "-" factor
|
||||||
|
funcCall ::= identifier "(" argList? ")"
|
||||||
|
parenExpression ::= "(" ws expression ws ")"
|
||||||
|
|
||||||
|
argList ::= expression ("," ws expression)*
|
||||||
|
|
||||||
|
number ::= [0-9]+
|
||||||
|
|
||||||
|
singleLineComment ::= "//" [^\n]* "\n"
|
||||||
|
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"
|
||||||
|
|
||||||
|
ws ::= ([ \t\n]+)
|
13
grammars/chess.gbnf
Normal file
13
grammars/chess.gbnf
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# Specifies chess moves as a list in algebraic notation, using PGN conventions
|
||||||
|
|
||||||
|
# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern
|
||||||
|
root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+
|
||||||
|
move ::= (pawn | nonpawn | castle) [+#]?
|
||||||
|
|
||||||
|
# piece type, optional file/rank, optional capture, dest file & rank
|
||||||
|
nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8]
|
||||||
|
|
||||||
|
# optional file & capture, dest file & rank, optional promotion
|
||||||
|
pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])?
|
||||||
|
|
||||||
|
castle ::= "O-O" "-O"?
|
7
grammars/japanese.gbnf
Normal file
7
grammars/japanese.gbnf
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
# A probably incorrect grammar for Japanese
|
||||||
|
root ::= jp-char+ ([ \t\n] jp-char+)*
|
||||||
|
jp-char ::= hiragana | katakana | punctuation | cjk
|
||||||
|
hiragana ::= [ぁ-ゟ]
|
||||||
|
katakana ::= [ァ-ヿ]
|
||||||
|
punctuation ::= [、-〾]
|
||||||
|
cjk ::= [一-鿿]
|
25
grammars/json.gbnf
Normal file
25
grammars/json.gbnf
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
root ::= object
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
|
object ::=
|
||||||
|
"{" ws (
|
||||||
|
string ":" ws value
|
||||||
|
("," ws string ":" ws value)*
|
||||||
|
)? "}" ws
|
||||||
|
|
||||||
|
array ::=
|
||||||
|
"[" ws (
|
||||||
|
value
|
||||||
|
("," ws value)*
|
||||||
|
)? "]" ws
|
||||||
|
|
||||||
|
string ::=
|
||||||
|
"\"" (
|
||||||
|
[^"\\] |
|
||||||
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
|
)* "\"" ws
|
||||||
|
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
|
ws ::= ([ \t\n] ws)?
|
34
grammars/json_arr.gbnf
Normal file
34
grammars/json_arr.gbnf
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
|
||||||
|
# Useful for generating JSON arrays
|
||||||
|
|
||||||
|
root ::= arr
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
|
arr ::=
|
||||||
|
"[\n" ws (
|
||||||
|
value
|
||||||
|
(",\n" ws value)*
|
||||||
|
)? "]"
|
||||||
|
|
||||||
|
object ::=
|
||||||
|
"{" ws (
|
||||||
|
string ":" ws value
|
||||||
|
("," ws string ":" ws value)*
|
||||||
|
)? "}" ws
|
||||||
|
|
||||||
|
array ::=
|
||||||
|
"[" ws (
|
||||||
|
value
|
||||||
|
("," ws value)*
|
||||||
|
)? "]" ws
|
||||||
|
|
||||||
|
string ::=
|
||||||
|
"\"" (
|
||||||
|
[^"\\] |
|
||||||
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
|
)* "\"" ws
|
||||||
|
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
|
ws ::= ([ \t\n] ws)?
|
4
grammars/list.gbnf
Normal file
4
grammars/list.gbnf
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
root ::= item+
|
||||||
|
|
||||||
|
# Excludes various line break characters
|
||||||
|
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
|
@ -1,5 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -42,6 +43,8 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits):
|
|||||||
class LlamaCppModel:
|
class LlamaCppModel:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.initialized = False
|
self.initialized = False
|
||||||
|
self.grammar_file = 'None'
|
||||||
|
self.grammar = None
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.model.__del__()
|
self.model.__del__()
|
||||||
@ -107,6 +110,17 @@ class LlamaCppModel:
|
|||||||
logits = np.expand_dims(logits, 0) # batch dim is expected
|
logits = np.expand_dims(logits, 0) # batch dim is expected
|
||||||
return torch.tensor(logits, dtype=torch.float32)
|
return torch.tensor(logits, dtype=torch.float32)
|
||||||
|
|
||||||
|
def load_grammar(self, fname):
|
||||||
|
if fname != self.grammar_file:
|
||||||
|
self.grammar_file = fname
|
||||||
|
p = Path(f'grammars/{fname}')
|
||||||
|
print(p)
|
||||||
|
if p.exists():
|
||||||
|
logger.info(f'Loading the following grammar file: {p}')
|
||||||
|
self.grammar = llama_cpp_lib().LlamaGrammar.from_file(str(p))
|
||||||
|
else:
|
||||||
|
self.grammar = None
|
||||||
|
|
||||||
def generate(self, prompt, state, callback=None):
|
def generate(self, prompt, state, callback=None):
|
||||||
|
|
||||||
LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
|
LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
|
||||||
@ -118,6 +132,7 @@ class LlamaCppModel:
|
|||||||
prompt = prompt[-get_max_prompt_length(state):]
|
prompt = prompt[-get_max_prompt_length(state):]
|
||||||
prompt = self.decode(prompt)
|
prompt = self.decode(prompt)
|
||||||
|
|
||||||
|
self.load_grammar(state['grammar_file'])
|
||||||
logit_processors = LogitsProcessorList()
|
logit_processors = LogitsProcessorList()
|
||||||
if state['ban_eos_token']:
|
if state['ban_eos_token']:
|
||||||
logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos()))
|
logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos()))
|
||||||
@ -140,6 +155,7 @@ class LlamaCppModel:
|
|||||||
mirostat_eta=state['mirostat_eta'],
|
mirostat_eta=state['mirostat_eta'],
|
||||||
stream=True,
|
stream=True,
|
||||||
logits_processor=logit_processors,
|
logits_processor=logit_processors,
|
||||||
|
grammar=self.grammar
|
||||||
)
|
)
|
||||||
|
|
||||||
output = ""
|
output = ""
|
||||||
|
@ -305,6 +305,7 @@ loaders_samplers = {
|
|||||||
'mirostat_mode',
|
'mirostat_mode',
|
||||||
'mirostat_tau',
|
'mirostat_tau',
|
||||||
'mirostat_eta',
|
'mirostat_eta',
|
||||||
|
'grammar_file',
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
'custom_token_bans',
|
'custom_token_bans',
|
||||||
},
|
},
|
||||||
|
@ -114,6 +114,7 @@ def list_interface_input_elements():
|
|||||||
'mirostat_mode',
|
'mirostat_mode',
|
||||||
'mirostat_tau',
|
'mirostat_tau',
|
||||||
'mirostat_eta',
|
'mirostat_eta',
|
||||||
|
'grammar_file',
|
||||||
'negative_prompt',
|
'negative_prompt',
|
||||||
'guidance_scale',
|
'guidance_scale',
|
||||||
'add_bos_token',
|
'add_bos_token',
|
||||||
|
@ -108,6 +108,9 @@ def create_ui(default_preset):
|
|||||||
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
|
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
|
||||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['grammar_file'] = gr.Dropdown(value='None', choices=utils.get_available_grammars(), label='Grammar file (GBNF)', elem_classes='slim-dropdown')
|
||||||
|
ui.create_refresh_button(shared.gradio['grammar_file'], lambda: None, lambda: {'choices': utils.get_available_grammars()}, 'refresh-button')
|
||||||
|
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -124,3 +124,7 @@ def get_datasets(path: str, ext: str):
|
|||||||
|
|
||||||
def get_available_chat_styles():
|
def get_available_chat_styles():
|
||||||
return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)
|
return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_grammars():
|
||||||
|
return ['None'] + sorted([item.name for item in list(Path('grammars').glob('*.gbnf'))], key=natural_keys)
|
||||||
|
Loading…
Reference in New Issue
Block a user