From 08cf150c0c452ea7fad48c0ef048dd8564742371 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 24 Sep 2023 18:05:24 -0300 Subject: [PATCH] Add a grammar editor to the UI (#4061) --- api-examples/api-example-chat-stream.py | 2 +- api-examples/api-example-chat.py | 2 +- api-examples/api-example-stream.py | 2 +- api-examples/api-example.py | 2 +- extensions/api/util.py | 2 +- extensions/openai/defaults.py | 2 +- modules/llamacpp_model.py | 16 ++--- modules/loaders.py | 1 + modules/ui.py | 2 +- modules/ui_file_saving.py | 12 ++++ modules/ui_parameters.py | 94 +++++++++++++------------ 11 files changed, 75 insertions(+), 62 deletions(-) diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py index abf05e11..bfa5d4f5 100644 --- a/api-examples/api-example-chat-stream.py +++ b/api-examples/api-example-chat-stream.py @@ -63,7 +63,7 @@ async def run(user_input, history): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, - 'grammar_file': '', + 'grammar_string': '', 'guidance_scale': 1, 'negative_prompt': '', diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py index f53dbe4c..b2a1e1e4 100644 --- a/api-examples/api-example-chat.py +++ b/api-examples/api-example-chat.py @@ -57,7 +57,7 @@ def run(user_input, history): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, - 'grammar_file': '', + 'grammar_string': '', 'guidance_scale': 1, 'negative_prompt': '', diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py index fbae4a6c..966ca6f6 100644 --- a/api-examples/api-example-stream.py +++ b/api-examples/api-example-stream.py @@ -46,7 +46,7 @@ async def run(context): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, - 'grammar_file': '', + 'grammar_string': '', 'guidance_scale': 1, 'negative_prompt': '', diff --git a/api-examples/api-example.py b/api-examples/api-example.py index aae2a1ae..d9fd60d0 100644 --- a/api-examples/api-example.py +++ b/api-examples/api-example.py @@ -38,7 +38,7 @@ def run(prompt): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, - 'grammar_file': '', + 'grammar_string': '', 'guidance_scale': 1, 'negative_prompt': '', diff --git a/extensions/api/util.py b/extensions/api/util.py index 53111141..2e42770d 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -44,7 +44,7 @@ def build_parameters(body, chat=False): 'mirostat_mode': int(body.get('mirostat_mode', 0)), 'mirostat_tau': float(body.get('mirostat_tau', 5)), 'mirostat_eta': float(body.get('mirostat_eta', 0.1)), - 'grammar_file': str(body.get('grammar_file', '')), + 'grammar_string': str(body.get('grammar_string', '')), 'guidance_scale': float(body.get('guidance_scale', 1)), 'negative_prompt': str(body.get('negative_prompt', '')), 'seed': int(body.get('seed', -1)), diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py index 4c1da893..2ebade82 100644 --- a/extensions/openai/defaults.py +++ b/extensions/openai/defaults.py @@ -34,7 +34,7 @@ default_req_params = { 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1, - 'grammar_file': '', + 'grammar_string': '', 'guidance_scale': 1, 'negative_prompt': '', 'ban_eos_token': False, diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py index 83d02b7d..6b011e61 100644 --- a/modules/llamacpp_model.py +++ b/modules/llamacpp_model.py @@ -43,7 +43,7 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits): class LlamaCppModel: def __init__(self): self.initialized = False - self.grammar_file = 'None' + self.grammar_string = '' self.grammar = None def __del__(self): @@ -110,13 +110,11 @@ class LlamaCppModel: logits = np.expand_dims(logits, 0) # batch dim is expected return torch.tensor(logits, dtype=torch.float32) - def load_grammar(self, fname): - if fname != self.grammar_file: - self.grammar_file = fname - p = Path(f'grammars/{fname}') - if p.exists(): - logger.info(f'Loading the following grammar file: {p}') - self.grammar = llama_cpp_lib().LlamaGrammar.from_file(str(p)) + def load_grammar(self, string): + if string != self.grammar_string: + self.grammar_string = string + if string.strip() != '': + self.grammar = llama_cpp_lib().LlamaGrammar.from_string(string) else: self.grammar = None @@ -131,7 +129,7 @@ class LlamaCppModel: prompt = prompt[-get_max_prompt_length(state):] prompt = self.decode(prompt) - self.load_grammar(state['grammar_file']) + self.load_grammar(state['grammar_string']) logit_processors = LogitsProcessorList() if state['ban_eos_token']: logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos())) diff --git a/modules/loaders.py b/modules/loaders.py index 50ebb060..56c21c49 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -306,6 +306,7 @@ loaders_samplers = { 'mirostat_tau', 'mirostat_eta', 'grammar_file_row', + 'grammar_string', 'ban_eos_token', 'custom_token_bans', }, diff --git a/modules/ui.py b/modules/ui.py index d6da5cfc..1b5ab955 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -114,7 +114,7 @@ def list_interface_input_elements(): 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', - 'grammar_file', + 'grammar_string', 'negative_prompt', 'guidance_scale', 'add_bos_token', diff --git a/modules/ui_file_saving.py b/modules/ui_file_saving.py index d80378af..1ba8e4bb 100644 --- a/modules/ui_file_saving.py +++ b/modules/ui_file_saving.py @@ -73,3 +73,15 @@ def create_event_handlers(): lambda x: f'{x}.yaml', gradio('preset_menu'), gradio('delete_filename')).then( lambda: 'presets/', None, gradio('delete_root')).then( lambda: gr.update(visible=True), None, gradio('file_deleter')) + + shared.gradio['save_grammar'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda x: x, gradio('grammar_string'), gradio('save_contents')).then( + lambda: 'grammars/', None, gradio('save_root')).then( + lambda: 'My Fancy Grammar.gbnf', None, gradio('save_filename')).then( + lambda: gr.update(visible=True), None, gradio('file_saver')) + + shared.gradio['delete_grammar'].click( + lambda x: x, gradio('grammar_file'), gradio('delete_filename')).then( + lambda: 'grammars/', None, gradio('delete_root')).then( + lambda: gr.update(visible=True), None, gradio('file_deleter')) diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 7e7aa667..fcea8054 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -1,3 +1,5 @@ +from pathlib import Path + import gradio as gr from modules import loaders, presets, shared, ui, ui_chat, utils @@ -21,27 +23,36 @@ def create_ui(default_preset): with gr.Row(): with gr.Column(): - with gr.Box(): - with gr.Row(): - with gr.Column(): - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') - shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p') - shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k') - shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p') - shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff') - shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff') - shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs') - shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01, label='top_a') + with gr.Row(): + with gr.Column(): + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') + shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p') + shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k') + shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty') + shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range') + shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p') + shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs') + shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01, label='top_a') + shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff') + shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff') - with gr.Column(): - shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty') - shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range') + with gr.Column(): + shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.') + shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', lines=3, elem_classes=['add_scrollbar']) + shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.') + shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.') + shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau') + shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta') + shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') + shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') + with gr.Accordion('Other parameters', open=False): shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty') shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length') - shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') - shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') + shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.') + shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') + shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') with gr.Accordion("Learn more", open=False): gr.Markdown(""" @@ -94,37 +105,28 @@ def create_ui(default_preset): """, elem_classes="markdown") with gr.Column(): - with gr.Box(): - with gr.Row(): - with gr.Column(): - shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.') - shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', lines=3, elem_classes=['add_scrollbar']) - shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.') - shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau') - shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta') + with gr.Row(): + with gr.Column(): + shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.') + shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum number of tokens/second', info='To make text readable in real time.') + shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"') + shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Custom token bans', info='Specific token IDs to ban from generating, comma-separated. The IDs can be found in the Default or Notebook tab.') - with gr.Column(): - shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.') - shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.') - shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') - shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') - with gr.Row() as shared.gradio['grammar_file_row']: - shared.gradio['grammar_file'] = gr.Dropdown(value='None', choices=utils.get_available_grammars(), label='Grammar file (GBNF)', elem_classes='slim-dropdown') - ui.create_refresh_button(shared.gradio['grammar_file'], lambda: None, lambda: {'choices': utils.get_available_grammars()}, 'refresh-button') + with gr.Column(): + shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.') + shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.') + shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.') + shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') + shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming') - with gr.Box(): - with gr.Row(): - with gr.Column(): - shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.') - shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum number of tokens/second', info='To make text readable in real time.') - shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"') - with gr.Column(): - shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.') - shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.') - shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Custom token bans', info='Specific token IDs to ban from generating, comma-separated. The IDs can be found in the Default or Notebook tab.') - shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.') - shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') - shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming') + with gr.Row() as shared.gradio['grammar_file_row']: + shared.gradio['grammar_file'] = gr.Dropdown(value='None', choices=utils.get_available_grammars(), label='Load grammar from file (.gbnf)', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['grammar_file'], lambda: None, lambda: {'choices': utils.get_available_grammars()}, 'refresh-button') + shared.gradio['save_grammar'] = gr.Button('💾', elem_classes='refresh-button') + shared.gradio['delete_grammar'] = gr.Button('🗑️ ', elem_classes='refresh-button') + + shared.gradio['grammar_string'] = gr.Textbox(value='', label='Grammar', lines=16, elem_classes=['add_scrollbar', 'monospace']) + shared.gradio['grammar_file'].change(lambda x: open(Path(f'grammars/{x}'), 'r').read(), gradio('grammar_file'), gradio('grammar_string')) ui_chat.create_chat_settings_ui()