Add a 'Count tokens' button

This commit is contained in:
oobabooga 2023-04-21 17:18:34 -03:00
parent a6ef2429fa
commit c238ba9532

View File

@ -1,7 +1,8 @@
import os import os
import requests
import warnings import warnings
import requests
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
os.environ['BITSANDBYTES_NOWELCOME'] = '1' os.environ['BITSANDBYTES_NOWELCOME'] = '1'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
@ -15,10 +16,12 @@ def my_get(url, **kwargs):
original_get = requests.get original_get = requests.get
requests.get = my_get requests.get = my_get
import gradio as gr import gradio as gr
requests.get = original_get requests.get = original_get
# This fixes LaTeX rendering on some systems # This fixes LaTeX rendering on some systems
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
import importlib import importlib
@ -44,7 +47,8 @@ from modules import api, chat, shared, training, ui
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model from modules.models import load_model, load_soft_prompt, unload_model
from modules.text_generation import generate_reply, stop_everything_event from modules.text_generation import (encode, generate_reply,
stop_everything_event)
def get_available_models(): def get_available_models():
@ -172,6 +176,11 @@ def load_prompt(fname):
return text return text
def count_tokens(text):
tokens = len(encode(text)[0])
return f'{tokens} tokens in the input.'
def download_model_wrapper(repo_id): def download_model_wrapper(repo_id):
try: try:
downloader = importlib.import_module("download-model") downloader = importlib.import_module("download-model")
@ -628,6 +637,7 @@ def create_interface():
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button') ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button')
shared.gradio['save_prompt'] = gr.Button('Save prompt') shared.gradio['save_prompt'] = gr.Button('Save prompt')
shared.gradio['count_tokens'] = gr.Button('Count tokens')
shared.gradio['status'] = gr.Markdown('') shared.gradio['status'] = gr.Markdown('')
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
@ -644,10 +654,11 @@ def create_interface():
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox_default", lines=27, label='Input') shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox_default", lines=27, label='Input')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
with gr.Row(): with gr.Row():
shared.gradio['Generate'] = gr.Button('Generate', variant='primary') shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
shared.gradio['Stop'] = gr.Button('Stop') shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button")
shared.gradio['Continue'] = gr.Button('Continue') shared.gradio['Continue'] = gr.Button('Continue', elem_classes="small-button")
shared.gradio['save_prompt'] = gr.Button('Save prompt') shared.gradio['save_prompt'] = gr.Button('Save prompt', elem_classes="small-button")
shared.gradio['count_tokens'] = gr.Button('Count tokens', elem_classes="small-button")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -838,8 +849,9 @@ def create_interface():
) )
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False) shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
# Launch the interface # Launch the interface