feature to save prompts with custom names (#1583)

---------

Co-authored-by: LoopLooter <looplooter>
Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
LoopLooter 2023-05-17 08:30:45 +03:00 committed by GitHub
parent c9c6aa2b6e
commit aeb1b7a9c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -122,11 +122,19 @@ def upload_soft_prompt(file):
return name return name
def save_prompt(text): def open_save_prompt():
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt" fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
def save_prompt(text, fname):
if fname != "":
with open(Path(f'prompts/{fname}.txt'), 'w', encoding='utf-8') as f:
f.write(text) f.write(text)
return f"Saved to prompts/{fname}"
return f"Saved to prompts/{fname}.txt", gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
else:
return "Error: No prompt name given.", gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
def load_prompt(fname): def load_prompt(fname):
@ -657,7 +665,9 @@ def create_interface():
shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt') shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt')
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button') ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button')
shared.gradio['save_prompt'] = gr.Button('Save prompt') shared.gradio['open_save_prompt'] = gr.Button('Save prompt')
shared.gradio['save_prompt'] = gr.Button('Confirm save prompt', visible=False)
shared.gradio['prompt_to_save'] = gr.Textbox(elem_classes="textbox_default", lines=1, label='Prompt name:', interactive=True, visible=False)
shared.gradio['count_tokens'] = gr.Button('Count tokens') shared.gradio['count_tokens'] = gr.Button('Count tokens')
shared.gradio['status'] = gr.Markdown('') shared.gradio['status'] = gr.Markdown('')
@ -678,7 +688,8 @@ def create_interface():
shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button") shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button") shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button")
shared.gradio['Continue'] = gr.Button('Continue', elem_classes="small-button") shared.gradio['Continue'] = gr.Button('Continue', elem_classes="small-button")
shared.gradio['save_prompt'] = gr.Button('Save prompt', elem_classes="small-button") shared.gradio['open_save_prompt'] = gr.Button('Save prompt', elem_classes="small-button")
shared.gradio['save_prompt'] = gr.Button('Confirm save prompt', visible=False, elem_classes="small-button")
shared.gradio['count_tokens'] = gr.Button('Count tokens', elem_classes="small-button") shared.gradio['count_tokens'] = gr.Button('Count tokens', elem_classes="small-button")
with gr.Row(): with gr.Row():
@ -688,6 +699,7 @@ def create_interface():
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button') ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button')
with gr.Column(): with gr.Column():
shared.gradio['prompt_to_save'] = gr.Textbox(elem_classes="textbox_default", lines=1, label='Prompt name:', interactive=True, visible=False)
shared.gradio['status'] = gr.Markdown('') shared.gradio['status'] = gr.Markdown('')
with gr.Column(): with gr.Column():
@ -871,7 +883,8 @@ def create_interface():
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False) shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, shared.gradio['textbox'], shared.gradio['status'], show_progress=False) shared.gradio['open_save_prompt'].click(open_save_prompt, None, [shared.gradio[k] for k in ['prompt_to_save', 'open_save_prompt', 'save_prompt']], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio[k] for k in ['textbox', 'prompt_to_save']], [shared.gradio[k] for k in ['status', 'prompt_to_save', 'open_save_prompt', 'save_prompt']], show_progress=False)
shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False) shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{js}}}")