From c238ba9532d0b504005f2557827291a993aad95a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 21 Apr 2023 17:18:34 -0300
Subject: [PATCH] Add a 'Count tokens' button

---
 server.py | 28 ++++++++++++++++++++--------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/server.py b/server.py
index 5cbce5e0..c79d63ec 100644
--- a/server.py
+++ b/server.py
@@ -1,7 +1,8 @@
 import os
-import requests
 import warnings
 
+import requests
+
 os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 os.environ['BITSANDBYTES_NOWELCOME'] = '1'
 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
@@ -15,10 +16,12 @@ def my_get(url, **kwargs):
 original_get = requests.get
 requests.get = my_get
 import gradio as gr
+
 requests.get = original_get
 
 # This fixes LaTeX rendering on some systems
 import matplotlib
+
 matplotlib.use('Agg')
 
 import importlib
@@ -44,7 +47,8 @@ from modules import api, chat, shared, training, ui
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt, unload_model
-from modules.text_generation import generate_reply, stop_everything_event
+from modules.text_generation import (encode, generate_reply,
+                                     stop_everything_event)
 
 
 def get_available_models():
@@ -172,6 +176,11 @@ def load_prompt(fname):
             return text
 
 
+def count_tokens(text):
+    tokens = len(encode(text)[0])
+    return f'{tokens} tokens in the input.'
+
+
 def download_model_wrapper(repo_id):
     try:
         downloader = importlib.import_module("download-model")
@@ -628,6 +637,7 @@ def create_interface():
                             ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button')
 
                         shared.gradio['save_prompt'] = gr.Button('Save prompt')
+                        shared.gradio['count_tokens'] = gr.Button('Count tokens')
                         shared.gradio['status'] = gr.Markdown('')
 
             with gr.Tab("Parameters", elem_id="parameters"):
@@ -644,10 +654,11 @@ def create_interface():
                     shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox_default", lines=27, label='Input')
                     shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
                     with gr.Row():
-                        shared.gradio['Generate'] = gr.Button('Generate', variant='primary')
-                        shared.gradio['Stop'] = gr.Button('Stop')
-                        shared.gradio['Continue'] = gr.Button('Continue')
-                        shared.gradio['save_prompt'] = gr.Button('Save prompt')
+                        shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
+                        shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button")
+                        shared.gradio['Continue'] = gr.Button('Continue', elem_classes="small-button")
+                        shared.gradio['save_prompt'] = gr.Button('Save prompt', elem_classes="small-button")
+                        shared.gradio['count_tokens'] = gr.Button('Count tokens', elem_classes="small-button")
 
                 with gr.Row():
                     with gr.Column():
@@ -838,8 +849,9 @@ def create_interface():
             )
 
             shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
-            shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
-            shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
+            shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False)
+            shared.gradio['save_prompt'].click(save_prompt, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
+            shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
 
         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
     # Launch the interface
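
The feature the patch adds reduces to one Gradio click handler: the 'Count tokens' button reads the input textbox, runs it through the tokenizer via `encode()`, and writes the resulting count into the `status` Markdown element. Below is a minimal standalone sketch of that wiring; since `modules.text_generation.encode` requires a loaded model, the sketch substitutes a hypothetical whitespace tokenizer purely to show the plumbing.

```python
import gradio as gr

def count_tokens(text):
    # Stand-in tokenizer for this sketch only: the webui's count_tokens()
    # calls modules.text_generation.encode(), which uses the loaded
    # model's tokenizer and returns a batch, hence len(encode(text)[0]).
    tokens = len(text.split())
    return f'{tokens} tokens in the input.'

with gr.Blocks() as demo:
    textbox = gr.Textbox(lines=5, label='Input')
    count_button = gr.Button('Count tokens')
    status = gr.Markdown('')

    # Same wiring as the patch: textbox is the input, the status
    # Markdown element is the output, no progress indicator.
    count_button.click(count_tokens, textbox, status, show_progress=False)

if __name__ == '__main__':
    demo.launch()
```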